diff --git a/Cargo.lock b/Cargo.lock index 6dd207bc1..4c92e0c55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,39 @@ dependencies = [ "winapi", ] +[[package]] +name = "async-stream" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "171374e7e3b2504e0e5236e3b59260560f9fe94bfe9ac39ba5e4e929c5590625" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "648ed8c8d2ce5409ccd57453d9d1b214b342a0d69376a6feda1fd6cae3299308" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + +[[package]] +name = "async-usercalls" +version = "0.1.0" +source = "git+https://github.com/fortanix/rust-sgx.git?branch=mz/async-usercalls#4cf2a8e12912bfd5e0ce8ef7fcf8f607110dfda2" +dependencies = [ + "crossbeam-channel", + "fnv", + "fortanix-sgx-abi", + "ipc-queue", + "lazy_static", +] + [[package]] name = "atty" version = "0.2.14" @@ -37,6 +70,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b671c8fb71b457dd4ae18c4ba1e59aa81793daacc361d82fcd410cef0d491875" +[[package]] +name = "autocfg" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" + [[package]] name = "base64" version = "0.9.3" @@ -122,6 +161,24 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7c3dd8985a7111efc5c80b44e23ecdd8c007de8ade3b96595387e812b957cf5" +[[package]] +name = "bytes" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" + +[[package]] +name = "bytes" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0dcbc35f504eb6fc275a6d20e4ebcda18cf50d40ba6fabff8c711fa16cb3b16" + +[[package]] +name = "bytes" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" + [[package]] name = "cc" version = "1.0.67" @@ -205,6 +262,27 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "crossbeam-channel" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b153fe7cbef478c567df0f972e02e6d736db11affe43dfc9c56a9374d1adfb87" +dependencies = [ + "crossbeam-utils", + "maybe-uninit", +] + +[[package]] +name = "crossbeam-utils" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8" +dependencies = [ + "autocfg 1.0.1", + "cfg-if 0.1.10", + "lazy_static", +] + [[package]] name = "env_logger" version = "0.8.3" @@ -218,12 +296,117 @@ dependencies = [ "termcolor", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "fortanix-sgx-abi" +version = "0.4.0" +source = "git+https://github.com/fortanix/rust-sgx.git?branch=mz/async-usercalls#4cf2a8e12912bfd5e0ce8ef7fcf8f607110dfda2" + [[package]] name = "fuchsia-cprng" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" +[[package]] +name = "futures" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12aa0eb539080d55c3f2d45a67c3b58b6b0773c1a3ca2dfec66d58c97fd66ca" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d" + +[[package]] +name = "futures-executor" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45025be030969d763025784f7f355043dc6bc74093e4ecc5000ca4dc50d8745c" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "522de2a0fe3e380f1bc577ba0474108faf3f6b18321dbf60b3b9c39a75073377" + +[[package]] +name = "futures-macro" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e4a4b95cea4b4ccbcf1c5675ca7c4ee4e9e75eb79944d07defde18068f79bb" +dependencies = [ + "autocfg 1.0.1", + "proc-macro-hack", + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + +[[package]] +name = "futures-sink" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36ea153c13024fe480590b3e3d4cad89a0cfacecc24577b68f86c6ced9c2bc11" + +[[package]] +name = "futures-task" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99" + +[[package]] +name = "futures-util" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481" +dependencies = [ + "autocfg 1.0.1", + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite 0.2.7", + "pin-utils", + "proc-macro-hack", + "proc-macro-nested", + "slab", +] + [[package]] name = "generic-array" version = "0.12.3" @@ -239,6 +422,32 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" +[[package]] +name = "h2" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e4728fd124914ad25e99e3d15a9361a879f6620f63cb56bbb08f95abb97a535" +dependencies = [ + "bytes 0.5.6", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio 0.2.25", + "tokio-util", + "tracing", + "tracing-futures", +] + +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" + [[package]] name = "hermit-abi" version = "0.1.17" @@ -254,12 +463,39 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "805026a5d0141ffc30abb3be3173848ad46a1b1664fe632428479619a3644d77" +[[package]] +name = "http" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11" +dependencies = [ + "bytes 1.1.0", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13d5ff830006f7646652e057693569bfe0d51760c0085a071769d142a205111b" +dependencies = [ + "bytes 0.5.6", + "http", +] + [[package]] name = "httparse" version = "1.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd179ae861f0c2e53da70d892f5f3029f9594be0c41dc5269cd371691b1dc2f9" +[[package]] +name = "httpdate" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "494b4d60369511e7dea41cf646832512a94e542f68bb9c49e54518e0f468eb47" + [[package]] name = "humantime" version = "2.1.0" @@ -285,6 +521,29 @@ dependencies = [ "url", ] +[[package]] +name = "hyper" +version = "0.13.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a6f157065790a3ed2f88679250419b5cdd96e714a0d65f7797fd337186e96bb" +dependencies = [ + "bytes 0.5.6", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project", + "tokio 0.2.25", + "tower-service", + "tracing", + "want", +] + [[package]] name = "idna" version = "0.1.5" @@ -296,6 +555,30 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5" +dependencies = [ + "autocfg 1.0.1", + "hashbrown", +] + +[[package]] +name = "ipc-queue" +version = "0.1.0" +source = "git+https://github.com/fortanix/rust-sgx.git?branch=mz/async-usercalls#4cf2a8e12912bfd5e0ce8ef7fcf8f607110dfda2" +dependencies = [ + "fortanix-sgx-abi", +] + +[[package]] +name = "itoa" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" + [[package]] name = "language-tags" version = "0.2.2" @@ -316,9 +599,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.62" +version = "0.2.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34fcd2c08d2f832f376f4173a231990fa5aef4e99fb569867318a227ef4c06ba" +checksum = "a2a5ac8f984bfcf3a823267e5fde638acc3325f6496633a5da6bb6eb2171e103" [[package]] name = "libloading" @@ -366,10 +649,17 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08" +[[package]] +name = "maybe-uninit" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00" + [[package]] name = "mbedtls" version = "0.8.1" dependencies = [ + "async-stream", "bit-vec", "bitflags", "block-modes", @@ -378,8 +668,10 @@ dependencies = [ "cfg-if 1.0.0", "chrono", "core_io", + "futures", "hex", - "hyper", + "hyper 0.10.16", + "hyper 0.13.10", "libc", "matches", "mbedtls-sys-auto", @@ -391,6 +683,9 @@ dependencies = [ "serde_cbor", "serde_derive", "spin", + "tokio 0.2.25", + "tokio 0.3.4", + "tracing", "yasna", ] @@ -424,6 +719,29 @@ dependencies = [ "log 0.3.9", ] +[[package]] +name = "mio" +version = "0.7.6" +source = "git+https://github.com/mzohreva/mio?branch=mz/sgx-port-0.7.6#b4370d8bea9951f7f01e29115b8ca0e9bfa25a77" +dependencies = [ + "async-usercalls", + "crossbeam-channel", + "libc", + "log 0.4.8", + "miow", + "ntapi", + "winapi", +] + +[[package]] +name = "miow" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" +dependencies = [ + "winapi", +] + [[package]] name = "nom" version = "5.1.2" @@ -434,13 +752,22 @@ dependencies = [ "version_check 0.9.2", ] +[[package]] +name = "ntapi" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +dependencies = [ + "winapi", +] + [[package]] name = "num-bigint" version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343b3df15c945a59e72aae31e89a7cfc9e11850e96d4fde6fed5e3c7c8d9c887" dependencies = [ - "autocfg", + "autocfg 0.1.6", "num-integer", "num-traits", ] @@ -451,7 +778,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b85e541ef8255f6cf42bbfe4ef361305c6c135d10919ecc26126c4e5ae94bc09" dependencies = [ - "autocfg", + "autocfg 0.1.6", "num-traits", ] @@ -461,7 +788,7 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ba9a427cfca2be13aa6f6403b0b7e7368fe982bfa16fccc450ce74c46cd9b32" dependencies = [ - "autocfg", + "autocfg 0.1.6", ] [[package]] @@ -498,12 +825,62 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31010dd2e1ac33d5b46a5b413495239882813e0369f8ed8a5e266f173602f831" +[[package]] +name = "pin-project" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "576bc800220cc65dac09e99e97b08b358cfab6e17078de8dc5fee223bd2d0c08" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e8fe8163d14ce7f0cdac2e040116f22eac817edabff0be91e8aff7e9accf389" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + +[[package]] +name = "pin-project-lite" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "257b64915a082f7811703966789728173279bdebb956b143dbcd23f6f970a777" + +[[package]] +name = "pin-project-lite" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72d5370d90f49f70bd033c3d75e87fc529fbfff9d6f7cccef07d6170079d91ea" +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" + +[[package]] +name = "proc-macro-nested" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" + [[package]] name = "proc-macro2" version = "0.4.30" @@ -613,9 +990,9 @@ checksum = "b5eb417147ba9860a96cfe72a0b93bf88fee1744b5636ec99ab20c1aa9376581" [[package]] name = "rs-libc" -version = "0.2.2" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b434763aff74b924c33af0ce3a3791c7c5ff8fb431773061dde30447e2fb77f0" +checksum = "80a671d6c4696a49b78e0a271c99bc58bc1a17a64893a3684a1ba1a944b26ca9" dependencies = [ "cc", ] @@ -690,6 +1067,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42a568c8f2cd051a4d283bd6eb0343ac214c1b0f1ac19f93e1175b2dee38c73d" +[[package]] +name = "slab" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590" + [[package]] name = "spin" version = "0.4.10" @@ -777,12 +1160,121 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +[[package]] +name = "tokio" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6703a273949a90131b290be1fe7b039d0fc884aa1935860dfcbe056f28cd8092" +dependencies = [ + "bytes 0.5.6", + "fnv", + "futures-core", + "memchr", + "pin-project-lite 0.1.12", +] + +[[package]] +name = "tokio" +version = "0.3.4" +source = "git+https://github.com/mzohreva/tokio?branch=mz/sgx-port-0.3.4#8af31a7b14986b34d6d544f48c2423e7b9792c7f" +dependencies = [ + "autocfg 1.0.1", + "bytes 0.6.0", + "lazy_static", + "libc", + "memchr", + "mio", + "num_cpus", + "pin-project-lite 0.2.7", + "slab", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "0.3.1" +source = "git+https://github.com/mzohreva/tokio?branch=mz/sgx-port-0.3.4#8af31a7b14986b34d6d544f48c2423e7b9792c7f" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + +[[package]] +name = "tokio-util" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be8242891f2b6cbef26a2d7e8605133c2c554cd35b3e4948ea892d6d68436499" +dependencies = [ + "bytes 0.5.6", + "futures-core", + "futures-sink", + "log 0.4.8", + "pin-project-lite 0.1.12", + "tokio 0.2.25", +] + +[[package]] +name = "tower-service" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "360dfd1d6d30e05fda32ace2c8c70e9c0a9da713275777f5a4dbb8a1893930c6" + +[[package]] +name = "tracing" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2ba9ab62b7d6497a8638dfda5e5c4fb3b2d5a7fca4118f2b96151c8ef1a437e" +dependencies = [ + "cfg-if 1.0.0", + "log 0.4.8", + "pin-project-lite 0.2.7", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98863d0dd09fa59a1b79c6750ad80dbda6b75f4e71c437a6a1a8cb91a8bcbd77" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.9", + "syn 1.0.64", +] + +[[package]] +name = "tracing-core" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46125608c26121c81b0c6d693eab5a420e416da7e43c426d2e8f7df8da8a3acf" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "tracing-futures" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" +dependencies = [ + "pin-project", + "tracing", +] + [[package]] name = "traitobject" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efd1f82c56340fdf16f2a953d7bda4f8fdffba13d93b00844c25572110b26079" +[[package]] +name = "try-lock" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" + [[package]] name = "typeable" version = "0.1.2" @@ -875,6 +1367,16 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed" +[[package]] +name = "want" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0" +dependencies = [ + "log 0.4.8", + "try-lock", +] + [[package]] name = "which" version = "3.0.0" @@ -886,9 +1388,9 @@ dependencies = [ [[package]] name = "winapi" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" dependencies = [ "winapi-i686-pc-windows-gnu", "winapi-x86_64-pc-windows-gnu", diff --git a/Cargo.toml b/Cargo.toml index 755259713..c5e8f20b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,7 @@ [workspace] members = ["mbedtls", "mbedtls-sys"] + +[patch.crates-io] +mio = { git = "https://github.com/mzohreva/mio", branch = "mz/sgx-port-0.7.6" } +tokio = { git = "https://github.com/mzohreva/tokio", branch = "mz/sgx-port-0.3.4" } + diff --git a/mbedtls-sys/Cargo.toml b/mbedtls-sys/Cargo.toml index 14093449d..4419524a9 100644 --- a/mbedtls-sys/Cargo.toml +++ b/mbedtls-sys/Cargo.toml @@ -42,9 +42,8 @@ quote = "1.0.9" # * strstr/strlen/strncpy/strncmp/strcmp/snprintf # * memmove/memcpy/memcmp/memset # * rand/printf (used only for self tests. optionally use custom_printf) -default = ["std", "debug", "threading", "zlib", "time", "aesni", "padlock", "legacy_protocols"] -std = ["debug"] # deprecated automatic enabling of debug, can be removed on major version bump -debug = [] +default = ["std", "threading", "zlib", "time", "aesni", "padlock", "legacy_protocols"] +std = [] # deprecated automatic enabling of debug, can be removed on major version bump custom_printf = [] custom_has_support = [] aes_alt = [] diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 3112a6402..6b0c5f866 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -30,13 +30,14 @@ bit-vec = { version = "0.5", optional = true } block-modes = { version = "0.3", optional = true } rc2 = { version = "0.3", optional = true } cfg-if = "1.0.0" +tokio = { version = "0.3.4", optional = true } [target.x86_64-fortanix-unknown-sgx.dependencies] -rs-libc = "0.2.0" +rs-libc = "0.1.0" chrono = "0.4" [dependencies.mbedtls-sys-auto] -version = "2.25.0" +version = "2.26.0" default-features = false features = ["custom_printf", "trusted_cert_callback", "threading"] path = "../mbedtls-sys" @@ -48,6 +49,11 @@ serde_cbor = "0.6" hex = "0.3" matches = "0.1.8" hyper = { version = "0.10.16", default-features = false } +hyper13 = { package = "hyper", version = "0.13", default-features = false, features = ["stream"] } +tokio-02 = { package = "tokio", version = "0.2", default-features = false } +async-stream = "0.3.0" +futures = "0.3" +tracing = "0.1" [build-dependencies] cc = "1.0" @@ -56,7 +62,6 @@ cc = "1.0" # Features are documented in the README default = ["std", "aesni", "time", "padlock"] std = ["mbedtls-sys-auto/std", "serde/std", "yasna"] -debug = ["mbedtls-sys-auto/debug"] no_std_deps = ["core_io", "spin"] force_aesni_support = ["mbedtls-sys-auto/custom_has_support", "mbedtls-sys-auto/aes_alt", "aesni"] mpi_force_c_code = ["mbedtls-sys-auto/mpi_force_c_code"] @@ -68,6 +73,9 @@ padlock = ["mbedtls-sys-auto/padlock"] dsa = ["std", "yasna", "num-bigint", "bit-vec"] pkcs12 = ["std", "yasna"] pkcs12_rc2 = ["pkcs12", "rc2", "block-modes"] +async = ["std", "tokio","tokio/net","tokio/io-util", "tokio/macros"] +async-rt = ["async", "tokio/rt", "tokio/sync", "tokio/rt-multi-thread"] +migration_mode=[] [[example]] name = "client" @@ -92,3 +100,20 @@ required-features = ["std"] [[test]] name = "hyper" required-features = ["std"] + +[[test]] +name = "hyper13" +required-features = ["std", "async-rt"] + +[[test]] +name = "async_session" +path = "tests/async_session.rs" +required-features = ["async-rt"] + + +[package.metadata.fortanix-sgx] +threads = 100 +heap-size = 0x40000000 +stack-size = 0x100000 +# The following are not processed by the EDP tools but are picked up by build-enclave.sh: +#isvprodid = 66 diff --git a/mbedtls/src/lib.rs b/mbedtls/src/lib.rs index a98f6f40b..b4ae55966 100644 --- a/mbedtls/src/lib.rs +++ b/mbedtls/src/lib.rs @@ -53,9 +53,11 @@ mod private; // needs to be pub for global visiblity #[doc(hidden)] -#[cfg(sys_threading_component = "custom")] + +#[cfg(all(sys_threading_component = "custom", not(feature = "migration_mode")))] pub mod threading; +#[cfg(not(feature = "migration_mode"))] cfg_if::cfg_if! { if #[cfg(any(feature = "force_aesni_support", target_env = "sgx"))] { // needs to be pub for global visiblity @@ -105,6 +107,7 @@ mod alloc_prelude { pub(crate) use rust_alloc::borrow::Cow; } +#[cfg(not(feature = "migration_mode"))] cfg_if::cfg_if! { if #[cfg(sys_time_component = "custom")] { use mbedtls_sys::types::{time_t, tm}; @@ -154,7 +157,7 @@ cfg_if::cfg_if! { /// /// The caller must ensure no other MbedTLS code is running when calling this /// function. -#[cfg(feature = "debug")] +#[cfg(all(feature = "debug", not(feature = "migration_mode")))] pub unsafe fn set_global_debug_threshold(threshold: i32) { mbedtls_sys::debug_set_threshold(threshold); } diff --git a/mbedtls/src/pk/dsa/mod.rs b/mbedtls/src/pk/dsa/mod.rs index fdf030149..bf868217b 100644 --- a/mbedtls/src/pk/dsa/mod.rs +++ b/mbedtls/src/pk/dsa/mod.rs @@ -217,9 +217,13 @@ fn sample_secret_value(upper_bound: &Mpi, rng: &mut F) -> Result Ok(c) } -fn encode_dsa_signature(r: &Mpi, s: &Mpi) -> Result> { - let r = BigUint::from_bytes_be(&r.to_binary()?); - let s = BigUint::from_bytes_be(&s.to_binary()?); +pub fn encode_dsa_signature(r: &Mpi, s: &Mpi) -> Result> { + serialize_signature(&r.to_binary()?, &s.to_binary()?) +} + +pub fn serialize_signature(r: &[u8], s: &[u8]) -> Result> { + let r = BigUint::from_bytes_be(r); + let s = BigUint::from_bytes_be(s); Ok(yasna::construct_der(|w| { w.write_sequence(|w| { @@ -229,6 +233,18 @@ fn encode_dsa_signature(r: &Mpi, s: &Mpi) -> Result> { })) } +pub fn deserialize_signature(signature: &Vec) -> Result<(Vec, Vec)> { + let (r,s) = yasna::parse_der(signature, |r| { + r.read_sequence(|rdr| { + let r = rdr.next().read_biguint()?; + let s = rdr.next().read_biguint()?; + Ok((r,s)) + }) + }).map_err(|_| Error::X509InvalidSignature)?; + + Ok((r.to_bytes_be(), s.to_bytes_be())) +} + impl DsaPrivateKey { pub fn from_components(params: DsaParams, x: Mpi) -> Result { if x <= Mpi::new(1)? || x >= params.q { diff --git a/mbedtls/src/pk/mod.rs b/mbedtls/src/pk/mod.rs index 1d42c6ca8..e4656c97c 100644 --- a/mbedtls/src/pk/mod.rs +++ b/mbedtls/src/pk/mod.rs @@ -201,34 +201,7 @@ define!( // // - Only used when creating/freeing - which is safe by design - eckey_alloc_wrap / eckey_free_wrap // -// 3. ECDSA: mbedtls_ecdsa_info at ../../../mbedtls-sys/vendor/crypto/library/pk_wrap.c:729 -// This does not use internal locks but avoids interior mutability. -// -// - Const access / copies context to stack based variables: -// ecdsa_verify_wrap: ../../../mbedtls-sys/vendor/crypto/library/pk_wrap.c:544 -// This copies the public key on the stack - in buf[] and copies the group id and nbits. -// That is done via: mbedtls_pk_write_pubkey( &p, buf, &key ) where key.pk_ctx = ctx; -// And the key is a const parameter to mbedtls_pk_write_pubkey - ../../../mbedtls-sys/vendor/crypto/library/pkwrite.c:158 -// -// - Const access with additional notes due to call stacks involved. -// -// ecdsa_sign_wrap: ../../../mbedtls-sys/vendor/crypto/library/pk_wrap.c:657 -// mbedtls_ecdsa_write_signature ../../../mbedtls-sys/vendor/crypto/library/ecdsa.c:688 -// mbedtls_ecdsa_write_signature_restartable ../../../mbedtls-sys/vendor/crypto/library/ecdsa.c:640 -// MBEDTLS_ECDSA_DETERMINISTIC is not defined. -// MBEDTLS_ECDSA_SIGN_ALT is not defined. -// Passes grp to: ecdsa_sign_restartable: ../../../mbedtls-sys/vendor/crypto/library/ecdsa.c:253 -// Const access to group - reads parameters, passed as const to mbedtls_ecp_gen_privkey, -// mbedtls_ecp_mul_restartable: ../../../mbedtls-sys/vendor/crypto/library/ecp.c:2351 -// MBEDTLS_ECP_INTERNAL_ALT is not defined. (otherwise it might not be safe depending on ecp_init/ecp_free) ../../../mbedtls-sys/build/config.rs:131 -// Passes as const to: mbedtls_ecp_check_privkey / mbedtls_ecp_check_pubkey / mbedtls_ecp_get_type( grp -// -// - Ignored due to not defined: ecdsa_verify_rs_wrap, ecdsa_sign_rs_wrap, ecdsa_rs_alloc, ecdsa_rs_free -// (Undefined - MBEDTLS_ECP_RESTARTABLE - ../../../mbedtls-sys/build/config.rs:173) -// -// - Only const access to context: eckey_check_pair -// -// - Only used when creating/freeing - which is safe by design: ecdsa_alloc_wrap, ecdsa_free_wrap +// 3. ECDSA - code uses mbedtls_pk wrappers. In this case code goes through ECKEY logic above. (mbedtls_pk_parse_key intentionally never calls mbedtls_pk_info_from_type with MBEDTLS_PK_ECDSA) // unsafe impl Sync for Pk {} @@ -826,7 +799,7 @@ impl Pk { /// /// On success, returns the actual number of bytes written to `sig`. pub fn sign( - &mut self, + &self, md: MdType, hash: &[u8], sig: &mut [u8], @@ -848,7 +821,7 @@ impl Pk { let mut ret = 0usize; unsafe { pk_sign( - &mut self.inner, + &self.inner as *const _ as *mut _, md.into(), hash.as_ptr(), hash.len(), @@ -912,10 +885,14 @@ impl Pk { } } - pub fn verify(&mut self, md: MdType, hash: &[u8], sig: &[u8]) -> Result<()> { + pub fn verify(&self, md: MdType, hash: &[u8], sig: &[u8]) -> Result<()> { + if hash.len() == 0 || sig.len() == 0 { + return Err(Error::PkBadInputData) + } + unsafe { pk_verify( - &mut self.inner, + &self.inner as *const _ as *mut _, md.into(), hash.as_ptr(), hash.len(), @@ -1240,7 +1217,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi #[test] fn rsa_sign_verify_pkcs1v15() { - let mut pk = + let pk = Pk::generate_rsa(&mut crate::test_support::rand::test_rng(), 2048, 0x10001).unwrap(); let data = b"SIGNATURE TEST SIGNATURE TEST SI"; let mut signature = vec![0u8; (pk.len() + 7) / 8]; diff --git a/mbedtls/src/rng/ctr_drbg.rs b/mbedtls/src/rng/ctr_drbg.rs index bf55a615d..533fec1e0 100644 --- a/mbedtls/src/rng/ctr_drbg.rs +++ b/mbedtls/src/rng/ctr_drbg.rs @@ -17,7 +17,12 @@ use mbedtls_sys::types::size_t; #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; use crate::error::{IntoResult, Result}; -use crate::rng::{EntropyCallback, RngCallback, RngCallbackMut}; +use crate::rng::{EntropyCallback, EntropyCallbackMut, RngCallback, RngCallbackMut}; + +enum EntropyHolder { + Shared(Arc), + Unique(Box), +} define!( // `ctr_drbg_context` inlines an `aes_context`, which is immovable. See @@ -30,7 +35,7 @@ define!( #[c_box_ty(ctr_drbg_context)] #[repr(C)] struct CtrDrbg { - entropy: Arc, + entropy: EntropyHolder, }; const drop: fn(&mut Self) = ctr_drbg_free; impl<'a> Into {} @@ -63,8 +68,28 @@ impl CtrDrbg { ).into_result()?; } - Ok(CtrDrbg { inner, entropy }) + Ok(CtrDrbg { inner, entropy: EntropyHolder::Shared(entropy) }) + } + + pub fn new_mut(entropy: T, additional_entropy: Option<&[u8]>) -> Result { + let mut inner = Box::new(ctr_drbg_context::default()); + + // We take sole ownership of entropy, all access is guarded via mutexes. + let mut entropy = Box::new(entropy); + unsafe { + ctr_drbg_init(&mut *inner); + ctr_drbg_seed( + &mut *inner, + Some(T::call_mut), + entropy.data_ptr_mut(), + additional_entropy.map(<[_]>::as_ptr).unwrap_or(::core::ptr::null()), + additional_entropy.map(<[_]>::len).unwrap_or(0) + ).into_result()?; + } + + Ok(CtrDrbg { inner, entropy: EntropyHolder::Unique(entropy) }) } + pub fn prediction_resistance(&self) -> bool { if self.inner.prediction_resistance == CTR_DRBG_PR_OFF { diff --git a/mbedtls/src/rust_printf.c b/mbedtls/src/rust_printf.c index c3b2ac93c..183552e0d 100644 --- a/mbedtls/src/rust_printf.c +++ b/mbedtls/src/rust_printf.c @@ -9,7 +9,7 @@ #include #include -extern void mbedtls_log(const char* msg); +extern void mbedtls8_log(const char* msg); extern int mbedtls_printf(const char *fmt, ...) { va_list ap; @@ -31,7 +31,7 @@ extern int mbedtls_printf(const char *fmt, ...) { if (n<0) return -1; - mbedtls_log(p); + mbedtls8_log(p); return n; } diff --git a/mbedtls/src/self_test.rs b/mbedtls/src/self_test.rs index 2dde8c3d9..648258964 100644 --- a/mbedtls/src/self_test.rs +++ b/mbedtls/src/self_test.rs @@ -25,7 +25,7 @@ cfg_if::cfg_if! { // needs to be pub for global visiblity #[doc(hidden)] #[no_mangle] - pub unsafe extern "C" fn mbedtls_log(msg: *const std::os::raw::c_char) { + pub unsafe extern "C" fn mbedtls8_log(msg: *const std::os::raw::c_char) { print!("{}", std::ffi::CStr::from_ptr(msg).to_string_lossy()); } } else { @@ -35,11 +35,13 @@ cfg_if::cfg_if! { // needs to be pub for global visiblity #[doc(hidden)] #[no_mangle] - pub unsafe extern "C" fn mbedtls_log(msg: *const c_char) { + pub unsafe extern "C" fn mbedtls8_log(msg: *const c_char) { log_f.expect("Called self-test log without enabling self-test")(msg) } } } + +#[cfg(not(feature = "migration_mode"))] cfg_if::cfg_if! { if #[cfg(any(not(feature = "std"), target_env = "sgx"))] { #[allow(non_upper_case_globals)] @@ -66,6 +68,7 @@ cfg_if::cfg_if! { /// The caller needs to ensure this function is not called while any other /// function in this module is called. #[allow(unused)] +#[cfg(not(feature = "migration_mode"))] pub unsafe fn enable(rand: fn() -> c_int, log: Option) { #[cfg(any(not(feature = "std"), target_env = "sgx"))] { rand_f = Some(rand); @@ -79,6 +82,7 @@ pub unsafe fn enable(rand: fn() -> c_int, log: Option) /// /// The caller needs to ensure this function is not called while any other /// function in this module is called. +#[cfg(not(feature = "migration_mode"))] pub unsafe fn disable() { #[cfg(any(not(feature = "std"), target_env = "sgx"))] { rand_f = None; diff --git a/mbedtls/src/ssl/async_utils.rs b/mbedtls/src/ssl/async_utils.rs new file mode 100644 index 000000000..e876a3e20 --- /dev/null +++ b/mbedtls/src/ssl/async_utils.rs @@ -0,0 +1,145 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +#![cfg(all(feature = "std", feature = "async"))] + +use std::cell::Cell; +use std::ptr::null_mut; +use std::rc::Rc; +use std::task::{Context as TaskContext, Poll}; + + +#[cfg(not(feature = "std"))] +use core_io::{Error as IoError, Result as IoResult, ErrorKind as IoErrorKind}; +#[cfg(feature = "std")] +use std::io::{Error as IoError, Result as IoResult, ErrorKind as IoErrorKind}; + + +#[derive(Clone)] +pub struct ErasedContext(Rc>); + +unsafe impl Send for ErasedContext {} + +impl ErasedContext { + pub fn new() -> Self { + Self(Rc::new(Cell::new(null_mut()))) + } + + pub unsafe fn get(&self) -> Option<&mut TaskContext<'_>> { + let ptr = self.0.get(); + if ptr.is_null() { + None + } else { + Some(&mut *(ptr as *mut _)) + } + } + + pub fn set(&self, cx: &mut TaskContext<'_>) { + self.0.set(cx as *mut _ as *mut ()); + } + + pub fn clear(&self) { + self.0.set(null_mut()); + } +} + +// mbedtls_ssl_write() has some weird semantics w.r.t non-blocking I/O: +// +// > When this function returns MBEDTLS_ERR_SSL_WANT_WRITE/READ, it must be +// > called later **with the same arguments**, until it returns a value greater +// > than or equal to 0. When the function returns MBEDTLS_ERR_SSL_WANT_WRITE +// > there may be some partial data in the output buffer, however this is not +// > yet sent. +// +// WriteTracker is used to ensure we pass the same data in that scenario. +// +// Reference: +// https://tls.mbed.org/api/ssl_8h.html#a5bbda87d484de82df730758b475f32e5 +pub struct WriteTracker { + pending: Option>, +} + +struct DigestAndLen { + #[cfg(debug_assertions)] + digest: [u8; 20], // SHA-1 + len: usize, +} + +impl WriteTracker { + fn new() -> Self { + WriteTracker { + pending: None, + } + } + + #[cfg(debug_assertions)] + fn digest(buf: &[u8]) -> [u8; 20] { + use crate::hash::{Md, Type}; + let mut out = [0u8; 20]; + let res = Md::hash(Type::Sha1, buf, &mut out[..]); + assert_eq!(res, Ok(out.len())); + out + } + + pub fn adjust_buf<'a>(&self, buf: &'a [u8]) -> IoResult<&'a [u8]> { + match self.pending.as_ref() { + None => Ok(buf), + Some(pending) => { + if pending.len <= buf.len() { + let buf = &buf[..pending.len]; + + // We only do this check in debug mode since it's an expensive check. + #[cfg(debug_assertions)] + if Self::digest(buf) == pending.digest { + return Ok(buf); + } + + #[cfg(not(debug_assertions))] + return Ok(buf); + } + Err(IoError::new( + IoErrorKind::Other, + "mbedtls expects the same data if the previous call to poll_write() returned Poll::Pending" + )) + }, + } + } + + pub fn post_write(&mut self, buf: &[u8], res: &Poll>) { + match res { + &Poll::Pending => { + if self.pending.is_none() { + self.pending = Some(Box::new(DigestAndLen { + #[cfg(debug_assertions)] + digest: Self::digest(buf), + len: buf.len(), + })); + } + }, + _ => { + self.pending = None; + } + } + } +} + +pub struct IoAdapter { + pub inner: S, + pub ecx: ErasedContext, + pub write_tracker: WriteTracker, +} + +impl IoAdapter { + pub fn new(stream: S) -> Self { + Self { + inner: stream, + ecx: ErasedContext::new(), + write_tracker: WriteTracker::new(), + } + } +} diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index 7f2c5debc..b52450575 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -31,7 +31,7 @@ use crate::ssl::ticket::TicketCallback; use crate::x509::Certificate; use crate::x509::Crl; use crate::x509::Profile; -use crate::x509::VerifyError; +use crate::x509::certificate::{VerifyCallback, verify_callback}; #[allow(non_camel_case_types)] #[derive(Eq, PartialEq, PartialOrd, Ord, Debug, Copy, Clone)] @@ -98,12 +98,54 @@ define!( } ); -callback!(VerifyCallback: Fn(&Certificate, i32, &mut VerifyError) -> Result<()>); #[cfg(feature = "std")] callback!(DbgCallback: Fn(i32, Cow<'_, str>, i32, Cow<'_, str>) -> ()); callback!(SniCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>); callback!(CaCallback: Fn(&MbedtlsList) -> Result>); + +#[repr(transparent)] +pub struct NullTerminatedStrList { + c: Box<[*mut i8]>, +} + +unsafe impl Send for NullTerminatedStrList {} +unsafe impl Sync for NullTerminatedStrList {} + +impl NullTerminatedStrList { + pub fn new(list: &[&str]) -> Result { + let mut c = Vec::with_capacity(list.len() + 1); + + for s in list { + let cstr = ::std::ffi::CString::new(*s).map_err(|_| Error::SslBadInputData)?; + c.push(cstr.into_raw()); + } + + c.push(core::ptr::null_mut()); + + Ok(NullTerminatedStrList { + c: c.into_boxed_slice(), + }) + } + + pub fn as_ptr(&self) -> *const *const u8 { + self.c.as_ptr() as *const _ + } +} + +impl Drop for NullTerminatedStrList { + fn drop(&mut self) { + for i in self.c.iter() { + unsafe { + if !(*i).is_null() { + let _ = ::std::ffi::CString::from_raw(*i); + } + } + } + } +} + + define!( #[c_ty(ssl_config)] #[repr(C)] @@ -120,9 +162,7 @@ define!( ciphersuites: Vec>>, curves: Option>>, - - #[allow(dead_code)] - dhm: Option>, + protocols: Option>, verify_callback: Option>, #[cfg(feature = "std")] @@ -158,7 +198,7 @@ impl Config { rng: None, ciphersuites: vec![], curves: None, - dhm: None, + protocols: None, verify_callback: None, #[cfg(feature = "std")] dbg_callback: None, @@ -188,6 +228,20 @@ impl Config { self.ciphersuites.push(list); } + /// Set the supported Application Layer Protocols. + /// + /// Each protocol name in the list must also be terminated with a null character (`\0`). + pub fn set_alpn_protocols(&mut self, protocols: Arc) -> Result<()> { + unsafe { + ssl_conf_alpn_protocols(&mut self.inner, protocols.as_ptr() as *mut _) + .into_result() + .map(|_| ())?; + } + + self.protocols = Some(protocols); + Ok(()) + } + pub fn set_ciphersuites_for_version(&mut self, list: Arc>, major: c_int, minor: c_int) { Self::check_c_list(&list); unsafe { ssl_conf_ciphersuites_for_version(self.into(), list.as_ptr(), major, minor) } @@ -236,13 +290,13 @@ impl Config { /// Takes both DER and PEM forms of FFDH parameters in `DHParams` format. /// /// When calling on PEM-encoded data, `params` must be NULL-terminated - pub fn set_dh_params(&mut self, dhm: Arc) -> Result<()> { + pub fn set_dh_params(&mut self, dhm: &Dhm) -> Result<()> { unsafe { + // This copies the dhm parameters and does not store any pointer to it ssl_conf_dh_param_ctx(self.into(), dhm.inner_ffi_mut()) .into_result() .map(|_| ())?; } - self.dhm = Some(dhm); Ok(()) } @@ -320,12 +374,10 @@ impl Config { // - We can pointer cast to it to allow storing additional objects. // let cb = &mut *(closure as *mut F); - let context = UnsafeFrom::from(ctx).unwrap(); - - let mut ctx = HandshakeContext::init(context); + let ctx = UnsafeFrom::from(ctx).unwrap(); let name = from_raw_parts(name, name_len); - match cb(&mut ctx, name) { + match cb(ctx, name) { Ok(()) => 0, Err(_) => -1, } @@ -343,38 +395,6 @@ impl Config { where F: VerifyCallback + 'static, { - unsafe extern "C" fn verify_callback( - closure: *mut c_void, - crt: *mut x509_crt, - depth: c_int, - flags: *mut u32, - ) -> c_int - where - F: VerifyCallback + 'static, - { - if crt.is_null() || closure.is_null() || flags.is_null() { - return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA; - } - - let cb = &mut *(closure as *mut F); - let crt: &mut Certificate = UnsafeFrom::from(crt).expect("valid certificate"); - - let mut verify_error = match VerifyError::from_bits(*flags) { - Some(ve) => ve, - // This can only happen if mbedtls is setting flags in VerifyError that are - // missing from our definition. - None => return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA, - }; - - let res = cb(crt, depth, &mut verify_error); - *flags = verify_error.bits(); - match res { - Ok(()) => 0, - Err(e) => e.to_int(), - } - } - - self.verify_callback = Some(Arc::new(cb)); unsafe { ssl_conf_verify(self.into(), Some(verify_callback::), &**self.verify_callback.as_mut().unwrap() as *const _ as *mut c_void) } } diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index f40a31f15..2fa037698 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -6,31 +6,45 @@ * option. This file may not be copied, modified, or distributed except * according to those terms. */ - -use core::any::Any; use core::result::Result as StdResult; -#[cfg(not(feature = "std"))] -use core_io::{Read, Write, Result as IoResult}; -#[cfg(feature = "std")] -use std::io::{Read, Write, Result as IoResult}; + #[cfg(feature = "std")] -use std::sync::Arc; +use { + std::io::{Read, Write, Result as IoResult}, + std::sync::Arc, +}; + +#[cfg(not(feature = "std"))] +use core_io::{Read, Write, Result as IoResult, ErrorKind as IoErrorKind}; + +#[cfg(all(feature = "std", feature = "async"))] +use { + std::io::{Error as IoError, ErrorKind as IoErrorKind}, + std::marker::Unpin, + std::pin::Pin, + std::task::{Context as TaskContext, Poll}, +}; + use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void}; use mbedtls_sys::types::size_t; use mbedtls_sys::*; -use crate::alloc::{List as MbedtlsList}; +#[cfg(all(feature = "std", feature = "async"))] +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + #[cfg(not(feature = "std"))] use crate::alloc_prelude::*; +use crate::alloc::{List as MbedtlsList}; use crate::error::{Error, Result, IntoResult}; use crate::pk::Pk; use crate::private::UnsafeFrom; use crate::ssl::config::{Config, Version, AuthMode}; +#[cfg(all(feature = "std", feature = "async"))] +use crate::ssl::async_utils::IoAdapter; use crate::x509::{Certificate, Crl, VerifyError}; - -pub trait IoCallback : Any { +pub trait IoCallback { unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int where Self: Sized; unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int where Self: Sized; fn data_ptr(&mut self) -> *mut c_void; @@ -70,13 +84,7 @@ impl IoCallback for IO { define!( #[c_ty(ssl_context)] #[repr(C)] - struct Context { - // config is used read-only for mutliple contexts and is immutable once configured. - config: Arc, - - // Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated. - io: Option>, - + struct HandshakeContext { handshake_ca_cert: Option>>, handshake_crl: Option>, @@ -89,7 +97,27 @@ define!( impl<'a> UnsafeFrom {} ); -impl Context { +define!( + #[c_custom_ty(ssl_context)] + #[repr(C)] + struct Context { + // Base structure used in SNI callback where we cannot determine the io type. + inner: HandshakeContext, + + // config is used read-only for mutliple contexts and is immutable once configured. + config: Arc, + + // Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated. + io: Option>, + }; + impl<'a> Into {} +); + +#[cfg(all(feature = "std", feature = "async"))] +pub type AsyncContext = Context>; + + +impl Context { pub fn new(config: Arc) -> Self { let mut inner = ssl_context::default(); @@ -99,19 +127,22 @@ impl Context { }; Context { - inner, + inner: HandshakeContext { + inner, + handshake_ca_cert: None, + handshake_crl: None, + + handshake_cert: vec![], + handshake_pk: vec![], + }, config: config.clone(), io: None, - - handshake_ca_cert: None, - handshake_crl: None, - - handshake_cert: vec![], - handshake_pk: vec![], } } +} - pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> { +impl Context { + pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> { unsafe { let mut io = Box::new(io); ssl_session_reset(self.into()).into_result()?; @@ -127,21 +158,35 @@ impl Context { ); self.io = Some(io); + self.inner.reset_handshake(); + } - self.handshake_cert.clear(); - self.handshake_pk.clear(); - self.handshake_ca_cert = None; - self.handshake_crl = None; - - match ssl_handshake(self.into()).into_result() { - Err(e) => { - // safely end borrow of io - ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None); - self.io = None; - Err(e) - }, - Ok(_) => { - Ok(()) + self.handshake() + } +} + +impl Context { + pub fn handshake(&mut self) -> Result<()> { + match unsafe { ssl_flush_output(self.into()).into_result() } { + Err(Error::SslWantRead) => Err(Error::SslWantRead), + Err(Error::SslWantWrite) => Err(Error::SslWantWrite), + Err(e) => { + unsafe { ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None); } + self.io = None; + Err(e) + }, + Ok(_) => { + match unsafe { ssl_handshake(self.into()).into_result() } { + Err(Error::SslWantRead) => Err(Error::SslWantRead), + Err(Error::SslWantWrite) => Err(Error::SslWantWrite), + Err(e) => { + unsafe { ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None); } + self.io = None; + Err(e) + }, + Ok(_) => { + Ok(()) + } } } } @@ -187,22 +232,21 @@ impl Context { self.io = None; } } - - pub fn io(&self) -> Option<&dyn Any> { + pub fn io(&self) -> Option<&T> { self.io.as_ref().map(|v| &**v) } - pub fn io_mut(&mut self) -> Option<&mut dyn Any> { + pub fn io_mut(&mut self) -> Option<&mut T> { self.io.as_mut().map(|v| &mut **v) } /// Return the minor number of the negotiated TLS version pub fn minor_version(&self) -> i32 { - self.inner.minor_ver + self.handle().minor_ver } /// Return the major number of the negotiated TLS version pub fn major_version(&self) -> i32 { - self.inner.major_ver + self.handle().major_ver } /// Return the number of bytes currently available to read that @@ -231,27 +275,41 @@ impl Context { /// All assigned ciphersuites are listed by the IANA in /// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt pub fn ciphersuite(&self) -> Result { - if self.inner.session.is_null() { + if self.handle().session.is_null() { return Err(Error::SslBadInputData); } - Ok(unsafe { self.inner.session.as_ref().unwrap().ciphersuite as u16 }) + Ok(unsafe { self.handle().session.as_ref().unwrap().ciphersuite as u16 }) } pub fn peer_cert(&self) -> Result>> { - if self.inner.session.is_null() { + if self.handle().session.is_null() { return Err(Error::SslBadInputData); } unsafe { // We cannot call the peer cert function as we need a pointer to a pointer to create the MbedtlsList, we need something in the heap / cannot use any local variable for that. - let peer_cert : &MbedtlsList = UnsafeFrom::from(&((*self.inner.session).peer_cert) as *const *mut x509_crt as *const *const x509_crt).ok_or(Error::SslBadInputData)?; + let peer_cert : &MbedtlsList = UnsafeFrom::from(&((*self.handle().session).peer_cert) as *const *mut x509_crt as *const *const x509_crt).ok_or(Error::SslBadInputData)?; Ok(Some(peer_cert)) } } + + + #[cfg(feature = "std")] + pub fn get_alpn_protocol(&self) -> Result> { + unsafe { + let ptr = ssl_get_alpn_protocol(self.handle()); + if ptr.is_null() { + Ok(None) + } else { + let s = std::ffi::CStr::from_ptr(ptr).to_str()?; + Ok(Some(s)) + } + } + } } -impl Drop for Context { +impl Drop for Context { fn drop(&mut self) { unsafe { self.close(); @@ -260,7 +318,7 @@ impl Drop for Context { } } -impl Read for Context { +impl Read for Context { fn read(&mut self, buf: &mut [u8]) -> IoResult { match unsafe { ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result() } { Err(Error::SslPeerCloseNotify) => Ok(0), @@ -270,7 +328,7 @@ impl Read for Context { } } -impl Write for Context { +impl Write for Context { fn write(&mut self, buf: &[u8]) -> IoResult { match unsafe { ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result() } { Err(Error::SslPeerCloseNotify) => Ok(0), @@ -283,12 +341,6 @@ impl Write for Context { Ok(()) } } - - -pub struct HandshakeContext<'ctx> { - pub context: &'ctx mut Context, -} - // // Class exists only during SNI callback that is configured from Config. // SNI Callback must provide input whos lifetime exceed the SNI closure to avoid memory corruptions. @@ -301,42 +353,44 @@ pub struct HandshakeContext<'ctx> { // - mbedtls not providing any callbacks on handshake finish. // - no reasonable way to obtain a storage within the sni callback tied to the handshake or to the rust Context. (without resorting to a unscalable map or pointer magic that mbedtls may invalidate) // -impl<'ctx> HandshakeContext<'ctx> { - - pub(crate) fn init(context: &'ctx mut Context) -> Self { - HandshakeContext { context } +impl HandshakeContext { + pub fn reset_handshake(&mut self) { + self.handshake_cert.clear(); + self.handshake_pk.clear(); + self.handshake_ca_cert = None; + self.handshake_crl = None; } pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> { - if self.context.inner.handshake as *const _ == ::core::ptr::null() { + if self.inner.handshake as *const _ == ::core::ptr::null() { return Err(Error::SslBadInputData); } - unsafe { ssl_set_hs_authmode(self.context.into(), am as i32) } + unsafe { ssl_set_hs_authmode(self.into(), am as i32) } Ok(()) } pub fn set_ca_list( &mut self, - chain: Arc>, + chain: Option>>, crl: Option>, ) -> Result<()> { // mbedtls_ssl_set_hs_ca_chain does not check for NULL handshake. - if self.context.inner.handshake as *const _ == ::core::ptr::null() { + if self.inner.handshake as *const _ == ::core::ptr::null() { return Err(Error::SslBadInputData); } // This will override current handshake CA chain. unsafe { ssl_set_hs_ca_chain( - self.context.into(), - chain.inner_ffi_mut(), + self.into(), + chain.as_ref().map(|chain| chain.inner_ffi_mut()).unwrap_or(::core::ptr::null_mut()), crl.as_ref().map(|crl| crl.inner_ffi_mut()).unwrap_or(::core::ptr::null_mut()), ); } - self.context.handshake_ca_cert = Some(chain); - self.context.handshake_crl = crl; + self.handshake_ca_cert = chain; + self.handshake_crl = crl; Ok(()) } @@ -350,21 +404,245 @@ impl<'ctx> HandshakeContext<'ctx> { key: Arc, ) -> Result<()> { // mbedtls_ssl_set_hs_own_cert does not check for NULL handshake. - if self.context.inner.handshake as *const _ == ::core::ptr::null() { + if self.inner.handshake as *const _ == ::core::ptr::null() { return Err(Error::SslBadInputData); } // This will append provided certificate pointers in internal structures. unsafe { - ssl_set_hs_own_cert(self.context.into(), chain.inner_ffi_mut(), key.inner_ffi_mut()).into_result()?; + ssl_set_hs_own_cert(self.into(), chain.inner_ffi_mut(), key.inner_ffi_mut()).into_result()?; } - self.context.handshake_cert.push(chain); - self.context.handshake_pk.push(key); + self.handshake_cert.push(chain); + self.handshake_pk.push(key); Ok(()) } } +#[cfg(all(feature = "std", feature = "async"))] +pub trait IoAsyncCallback { + unsafe extern "C" fn call_recv_async(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int where Self: Sized; + unsafe extern "C" fn call_send_async(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int where Self: Sized; +} + +#[cfg(all(feature = "std", feature = "async"))] +impl IoAsyncCallback for IoAdapter { + unsafe extern "C" fn call_recv_async(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + + let adapter = &mut *(user_data as *mut IoAdapter); + + if let Some(cx) = adapter.ecx.get() { + let mut buf = ReadBuf::new(::core::slice::from_raw_parts_mut(data, len)); + let stream = Pin::new(&mut adapter.inner); + + match stream.poll_read(cx, &mut buf) { + Poll::Ready(Ok(())) => buf.filled().len() as c_int, + Poll::Ready(Err(_)) => ::mbedtls_sys::ERR_NET_RECV_FAILED, + Poll::Pending => ::mbedtls_sys::ERR_SSL_WANT_READ, + } + } else { + ::mbedtls_sys::ERR_NET_RECV_FAILED + } + } + + unsafe extern "C" fn call_send_async(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + + let adapter = &mut *(user_data as *mut IoAdapter); + + if let Some(cx) = adapter.ecx.get() { + let stream = Pin::new(&mut adapter.inner); + + match stream.poll_write(cx, ::core::slice::from_raw_parts(data, len)) { + Poll::Ready(Ok(i)) => i as c_int, + Poll::Ready(Err(_)) => ::mbedtls_sys::ERR_NET_RECV_FAILED, + Poll::Pending => ::mbedtls_sys::ERR_SSL_WANT_WRITE, + } + } else { + ::mbedtls_sys::ERR_NET_RECV_FAILED + } + } +} + +#[cfg(all(feature = "std", feature = "async"))] +struct HandshakeFuture<'a, T>(&'a mut Context::>); + +#[cfg(all(feature = "std", feature = "async"))] +impl std::future::Future for HandshakeFuture<'_, T> { + type Output = Result<()>; + fn poll(mut self: Pin<&mut Self>, ctx: &mut TaskContext) -> std::task::Poll { + self.0.io_mut().ok_or(Error::NetInvalidContext)? + .ecx.set(ctx); + + let result = match self.0.handshake() { + Err(Error::SslWantRead) | + Err(Error::SslWantWrite) => { + Poll::Pending + }, + Err(e) => Poll::Ready(Err(e)), + Ok(()) => Poll::Ready(Ok(())) + }; + + self.0.io_mut().map(|v| v.ecx.clear()); + + result + } +} + +#[cfg(all(feature = "std", feature = "async"))] +impl AsyncContext { + pub async fn accept_async(config: Arc, io: T, hostname: Option<&str>) -> IoResult> { + let mut context = Self::new(config); + context.establish_async(io, hostname).await.map_err(|e| crate::private::error_to_io_error(e))?; + Ok(context) + } + + pub async fn establish_async(&mut self, io: T, hostname: Option<&str>) -> Result<()> { + unsafe { + let mut io = Box::new(IoAdapter::new(io)); + + ssl_session_reset(self.into()).into_result()?; + self.set_hostname(hostname)?; + + let ptr = &mut *io as *mut _ as *mut c_void; + ssl_set_bio( + self.into(), + ptr, + Some(IoAdapter::::call_send_async), + Some(IoAdapter::::call_recv_async), + None, + ); + + self.io = Some(io); + self.inner.reset_handshake(); + } + + HandshakeFuture(self).await + } +} + +#[cfg(all(feature = "std", feature = "async"))] +impl AsyncRead for Context> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + + if self.handle().session.is_null() { + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "stream has been shutdown"))); + } + + self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .ecx.set(cx); + + let result = match unsafe { ssl_read((&mut *self).into(), buf.initialize_unfilled().as_mut_ptr(), buf.initialize_unfilled().len()).into_result() } { + Err(Error::SslPeerCloseNotify) => Poll::Ready(Ok(())), + Err(Error::SslWantRead) => Poll::Pending, + Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), + Ok(i) => { + buf.advance(i as usize); + Poll::Ready(Ok(())) + } + }; + + self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .ecx.clear(); + + result + } +} + +#[cfg(all(feature = "std", feature = "async"))] +impl AsyncWrite for Context> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &[u8], + ) -> Poll> { + + if self.handle().session.is_null() { + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "stream has been shutdown"))); + } + + let buf = { + let io = self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))?; + io.ecx.set(cx); + io.write_tracker.adjust_buf(buf) + }?; + + + self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .ecx.set(cx); + + let result = match unsafe { ssl_write((&mut *self).into(), buf.as_ptr(), buf.len()).into_result() } { + Err(Error::SslPeerCloseNotify) => Poll::Ready(Ok(0)), + Err(Error::SslWantWrite) => Poll::Pending, + Err(e) => Poll::Ready(Err(crate::private::error_to_io_error(e))), + Ok(i) => Poll::Ready(Ok(i as usize)) + }; + + let io = self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))?; + + io.ecx.clear(); + io.write_tracker.post_write(buf, &result); + + cx.waker().clone().wake(); + + result + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + // We can only flush the actual IO here. + // To flush mbedtls we need writes with the same buffer until complete. + let io = &mut self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .inner; + let stream = Pin::new(io); + stream.poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + if self.handle().session.is_null() { + return Poll::Ready(Err(IoError::new(IoErrorKind::Other, "stream has been shutdown"))); + } + + self.io_mut().ok_or(IoError::new(IoErrorKind::Other, "stream has been shutdown"))? + .ecx.set(cx); + + let result = match unsafe { ssl_close_notify((&mut *self).into()).into_result() } { + Err(Error::SslWantRead) | + Err(Error::SslWantWrite) => Poll::Pending, + Err(e) => { + unsafe { ssl_set_bio((&mut *self).into(), ::core::ptr::null_mut(), None, None, None); } + self.io = None; + Poll::Ready(Err(crate::private::error_to_io_error(e))) + } + Ok(0) => { + unsafe { ssl_set_bio((&mut *self).into(), ::core::ptr::null_mut(), None, None, None); } + self.io = None; + Poll::Ready(Ok(())) + } + Ok(v) => { + unsafe { ssl_set_bio((&mut *self).into(), ::core::ptr::null_mut(), None, None, None); } + self.io = None; + Poll::Ready(Err(IoError::new(IoErrorKind::Other, format!("unexpected result from ssl_close_notify: {}", v)))) + } + }; + + self.io_mut().map(|v| v.ecx.clear()); + result + } +} + // ssl_get_alpn_protocol // ssl_get_max_frag_len diff --git a/mbedtls/src/ssl/mod.rs b/mbedtls/src/ssl/mod.rs index 430d439ea..196e824e0 100644 --- a/mbedtls/src/ssl/mod.rs +++ b/mbedtls/src/ssl/mod.rs @@ -10,6 +10,7 @@ pub mod ciphersuites; pub mod config; pub mod context; pub mod ticket; +pub mod async_utils; #[doc(inline)] pub use self::ciphersuites::CipherSuite; @@ -19,3 +20,6 @@ pub use self::config::{Config, Version, UseSessionTickets}; pub use self::context::Context; #[doc(inline)] pub use self::ticket::TicketContext; +#[cfg(all(feature = "std", feature = "async"))] +#[doc(inline)] +pub use self::context::{AsyncContext}; diff --git a/mbedtls/src/wrapper_macros.rs b/mbedtls/src/wrapper_macros.rs index 6c2d20156..d08ed0721 100644 --- a/mbedtls/src/wrapper_macros.rs +++ b/mbedtls/src/wrapper_macros.rs @@ -61,6 +61,10 @@ macro_rules! define { define_struct!(define $(#[$m])* struct $name $(lifetime $l)* inner $inner members $($($(#[$mm])* $member: $member_type,)*)*); define_struct!(<< $name $(lifetime $l)* inner $inner >> $($defs)*); }; + { #[c_custom_ty($inner:ident)] $(#[$m:meta])* struct $name:ident$(<$l:tt>)* $({ $($(#[$mm:meta])* $member:ident: $member_type:ty,)* })?; $($defs:tt)* } => { + define_struct!(define_custom $(#[$m])* struct $name $(lifetime $l)* inner $inner members $($($(#[$mm])* $member: $member_type,)*)*); + define_struct!(<< $name $(lifetime $l)* inner $inner >> $($defs)*); + }; // Do not use UnsafeFrom with 'c_box_ty'. That is currently not supported as its not needed anywhere, support may be added in the future if needed anywhere. { #[c_box_ty($inner:ident)] $(#[$m:meta])* struct $name:ident$(<$l:tt>)* $({ $($(#[$mm:meta])* $member:ident: $member_type:ty,)* })?; $($defs:tt)* } => { define_struct!(define_box $(#[$m])* struct $name $(lifetime $l)* inner $inner members $($($(#[$mm])* $member: $member_type,)*)*); @@ -109,6 +113,32 @@ macro_rules! define_enum { } macro_rules! define_struct { + { define_custom $(#[$m:meta])* struct $name:ident $(lifetime $l:tt)* inner $inner:ident members $($(#[$mm:meta])* $member:ident: $member_type:ty,)* } => { + as_item!( + #[allow(dead_code)] + $(#[$m])* + pub struct $name<$($l)*> { + $($(#[$mm])* $member: $member_type,)* + } + ); + + as_item!( + #[allow(dead_code)] + impl<$($l)*> $name<$($l)*> { + pub(crate) fn handle(&self) -> &::mbedtls_sys::$inner { + self.inner.handle() + } + + pub(crate) fn handle_mut(&mut self) -> &mut ::mbedtls_sys::$inner { + self.inner.handle_mut() + } + } + ); + + as_item!( + unsafe impl<$($l)*> Send for $name<$($l)*> {} + ); + }; { define $(#[$m:meta])* struct $name:ident $(lifetime $l:tt)* inner $inner:ident members $($(#[$mm:meta])* $member:ident: $member_type:ty,)* } => { as_item!( #[allow(dead_code)] diff --git a/mbedtls/src/x509/certificate.rs b/mbedtls/src/x509/certificate.rs index f5e7db533..64180e507 100644 --- a/mbedtls/src/x509/certificate.rs +++ b/mbedtls/src/x509/certificate.rs @@ -11,7 +11,7 @@ use core::iter::FromIterator; use core::ptr::NonNull; use mbedtls_sys::*; -use mbedtls_sys::types::raw_types::c_char; +use mbedtls_sys::types::raw_types::*; use crate::alloc::{List as MbedtlsList, Box as MbedtlsBox}; #[cfg(not(feature = "std"))] @@ -22,6 +22,7 @@ use crate::pk::Pk; use crate::private::UnsafeFrom; use crate::rng::Random; use crate::x509::Time; +use crate::x509::VerifyError; extern "C" { pub(crate) fn forward_mbedtls_calloc(n: mbedtls_sys::types::size_t, size: mbedtls_sys::types::size_t) -> *mut mbedtls_sys::types::raw_types::c_void; @@ -253,6 +254,77 @@ impl Certificate { } result.map(|_| ()) } + + pub fn verify_single( + cert: &MbedtlsBox, + ca: &MbedtlsBox, + err_info: Option<&mut String>, + ) -> Result<()> { + let mut flags = 0; + let result = unsafe { + x509_crt_verify( + cert.inner_ffi_mut(), + ca.inner_ffi_mut(), + ::core::ptr::null_mut(), + ::core::ptr::null(), + &mut flags, + None, + ::core::ptr::null_mut(), + ) + } + .into_result(); + + if result.is_err() { + if let Some(err_info) = err_info { + let verify_info = crate::private::alloc_string_repeat(|buf, size| unsafe { + let prefix = "\0"; + x509_crt_verify_info(buf, size, prefix.as_ptr() as *const _, flags) + }); + if let Ok(error_str) = verify_info { + *err_info = error_str; + } + } + } + result.map(|_| ()) + } + + pub fn verify_callback( + chain: &MbedtlsList, + trust_ca: &MbedtlsList, + err_info: Option<&mut String>, + cb: F, + ) -> Result<()> + where + F: VerifyCallback + 'static, + { + let mut flags = 0; + let result = unsafe { + x509_crt_verify( + chain.inner_ffi_mut(), + trust_ca.inner_ffi_mut(), + ::core::ptr::null_mut(), + ::core::ptr::null(), + &mut flags, + Some(verify_callback::), + &cb as *const _ as *mut c_void, + ) + } + .into_result(); + + if result.is_err() { + if let Some(err_info) = err_info { + let verify_info = crate::private::alloc_string_repeat(|buf, size| unsafe { + let prefix = "\0"; + x509_crt_verify_info(buf, size, prefix.as_ptr() as *const _, flags) + }); + if let Ok(error_str) = verify_info { + *err_info = error_str; + } + } + } + result.map(|_| ()) + } + } // TODO @@ -719,6 +791,39 @@ impl Extend> for MbedtlsList { } } +pub(crate) unsafe extern "C" fn verify_callback( + closure: *mut c_void, + crt: *mut x509_crt, + depth: c_int, + flags: *mut u32, +) -> c_int +where + F: VerifyCallback + 'static, +{ + if crt.is_null() || closure.is_null() || flags.is_null() { + return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA; + } + + let cb = &mut *(closure as *mut F); + let crt: &mut Certificate = UnsafeFrom::from(crt).expect("valid certificate"); + + let mut verify_error = match VerifyError::from_bits(*flags) { + Some(ve) => ve, + // This can only happen if mbedtls is setting flags in VerifyError that are + // missing from our definition. + None => return ::mbedtls_sys::ERR_X509_BAD_INPUT_DATA, + }; + + let res = cb(crt, depth, &mut verify_error); + *flags = verify_error.bits(); + match res { + Ok(()) => 0, + Err(e) => e.to_int(), + } +} + +callback!(VerifyCallback: Fn(&Certificate, i32, &mut VerifyError) -> Result<()>); + #[cfg(test)] mod tests { @@ -995,7 +1100,21 @@ cYp0bH/RcPTC0Z+ZaqSWMtfxRrk63MJQF9EXpDCdvQRcTMD9D85DJrMKn8aumq0M // try again after fixing the chain chain.push(c_int2.clone()); - Certificate::verify(&chain, &mut c_root, None).unwrap(); + + + let mut err_str = String::new(); + + let verify_callback = |_crt: &Certificate, _depth: i32, verify_flags: &mut VerifyError| { + verify_flags.remove(VerifyError::CERT_EXPIRED); + Ok(()) + }; + + let res = Certificate::verify_callback(&chain, &mut c_root, Some(&mut err_str), verify_callback); + + match res { + Ok(()) => (), + Err(e) => assert!(false, "Failed to verify, error: {}, err_str: {}", e, err_str), + }; } { @@ -1004,7 +1123,18 @@ cYp0bH/RcPTC0Z+ZaqSWMtfxRrk63MJQF9EXpDCdvQRcTMD9D85DJrMKn8aumq0M chain.push(c_int1.clone()); chain.push(c_int2.clone()); - Certificate::verify(&chain, &mut c_root, None).unwrap(); + let verify_callback = |_crt: &Certificate, _depth: i32, verify_flags: &mut VerifyError| { + verify_flags.remove(VerifyError::CERT_EXPIRED); + Ok(()) + }; + + let mut err_str = String::new(); + let res = Certificate::verify_callback(&chain, &mut c_root, Some(&mut err_str), verify_callback); + + match res { + Ok(()) => (), + Err(e) => assert!(false, "Failed to verify, error: {}, err_str: {}", e, err_str), + }; } } diff --git a/mbedtls/src/x509/mod.rs b/mbedtls/src/x509/mod.rs index 56d626337..bb08072c6 100644 --- a/mbedtls/src/x509/mod.rs +++ b/mbedtls/src/x509/mod.rs @@ -28,7 +28,6 @@ pub use self::profile::Profile; use mbedtls_sys::*; use mbedtls_sys::types::raw_types::c_uint; bitflags! { - #[doc(inline)] pub struct KeyUsage: c_uint { const DIGITAL_SIGNATURE = X509_KU_DIGITAL_SIGNATURE as c_uint; const NON_REPUDIATION = X509_KU_NON_REPUDIATION as c_uint; @@ -43,7 +42,6 @@ bitflags! { } bitflags! { - #[doc(inline)] pub struct VerifyError: u32 { const CERT_BAD_KEY = X509_BADCERT_BAD_KEY as u32; const CERT_BAD_MD = X509_BADCERT_BAD_MD as u32; diff --git a/mbedtls/tests/async_session.rs b/mbedtls/tests/async_session.rs new file mode 100644 index 000000000..b5442b8c3 --- /dev/null +++ b/mbedtls/tests/async_session.rs @@ -0,0 +1,295 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +#![cfg(not(target_env = "sgx"))] + +extern crate mbedtls; + +use std::sync::Arc; +use std::pin::Pin; +use std::future::Future; + +use mbedtls::pk::Pk; +use mbedtls::rng::CtrDrbg; +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context, Version}; +use mbedtls::x509::{Certificate, VerifyError}; +use mbedtls::Error; +use mbedtls::Result as TlsResult; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +mod support; +use support::entropy::entropy_new; +use support::keys; + +use mbedtls::ssl::async_utils::IoAdapter; + +async fn client( + conn: TcpStream, + min_version: Version, + max_version: Version, + exp_version: Option) -> TlsResult<()> { + + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + let cacert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?); + let expected_flags = VerifyError::empty(); + #[cfg(feature = "time")] + let expected_flags = expected_flags | VerifyError::CERT_EXPIRED; + { + let verify_callback = move |crt: &Certificate, depth: i32, verify_flags: &mut VerifyError| { + + match (crt.subject().unwrap().as_str(), depth, &verify_flags) { + ("CN=RootCA", 1, _) => (), + (keys::EXPIRED_CERT_SUBJECT, 0, flags) => assert_eq!(**flags, expected_flags), + _ => assert!(false), + }; + + verify_flags.remove(VerifyError::CERT_EXPIRED); //we check the flags at the end, + //so removing this flag here prevents the connections from failing with VerifyError + Ok(()) + }; + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_verify_callback(verify_callback); + config.set_ca_list(cacert, None); + config.set_min_version(min_version)?; + config.set_max_version(max_version)?; + let mut ctx = Context::new(Arc::new(config)); + + match ctx.establish_async(conn, None).await { + Ok(()) => { + assert_eq!(ctx.version(), exp_version.unwrap()); + } + Err(e) => { + match e { + Error::SslBadHsProtocolVersion => {assert!(exp_version.is_none())}, + Error::SslFatalAlertMessage => {}, + e => panic!("Unexpected error {}", e), + }; + return Ok(()); + } + }; + + let ciphersuite = ctx.ciphersuite().unwrap(); + ctx + .write_all(format!("Client2Server {:4x}", ciphersuite).as_bytes()) + .await + .unwrap(); + let mut buf = [0u8; 13 + 4 + 1]; + ctx.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, format!("Server2Client {:4x}", ciphersuite).as_bytes()); + } // drop verify_callback, releasing borrow of verify_args + Ok(()) +} + +async fn server( + conn: TcpStream, + min_version: Version, + max_version: Version, + exp_version: Option, +) -> TlsResult<()> { + let entropy = entropy_new(); + let rng = Arc::new(CtrDrbg::new(Arc::new(entropy), None)?); + let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes())?); + let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None)?); + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_min_version(min_version)?; + config.set_max_version(max_version)?; + config.push_cert(cert, key)?; + let mut ctx = Context::new(Arc::new(config)); + + match ctx.establish_async(conn, None).await { + Ok(()) => { + assert_eq!(ctx.version(), exp_version.unwrap()); + } + Err(e) => { + match e { + // client just closes connection instead of sending alert + Error::NetSendFailed => {assert!(exp_version.is_none())}, + Error::SslBadHsProtocolVersion => {}, + e => panic!("Unexpected error {}", e), + }; + return Ok(()); + } + }; + + //assert_eq!(ctx.get_alpn_protocol().unwrap().unwrap(), None); + let ciphersuite = ctx.ciphersuite().unwrap(); + ctx + .write_all(format!("Server2Client {:4x}", ciphersuite).as_bytes()) + .await + .unwrap(); + let mut buf = [0u8; 13 + 1 + 4]; + ctx.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, format!("Client2Server {:4x}", ciphersuite).as_bytes()); + Ok(()) +} + +async fn with_client(conn: TcpStream, f: F) -> R +where + F: FnOnce(Context>) -> Pin + Send>>, +{ + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None).unwrap()); + let cacert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes()).unwrap()); + + let verify_callback = move |_crt: &Certificate, _depth: i32, verify_flags: &mut VerifyError| { + verify_flags.remove(VerifyError::CERT_EXPIRED); + Ok(()) + }; + + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_verify_callback(verify_callback); + config.set_ca_list(cacert, None); + + let mut ctx = Context::new(Arc::new(config)); + ctx.establish_async(conn, None).await.unwrap(); + + f(ctx).await +} + +async fn with_server(conn: TcpStream, f: F) -> R +where + F: FnOnce(Context>) -> Pin + Send>>, +{ + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None).unwrap()); + let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes()).unwrap()); + let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None).unwrap()); + + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.push_cert(cert, key).unwrap(); + let mut ctx = Context::new(Arc::new(config)); + + ctx.establish_async(conn, None).await.unwrap(); + + f(ctx).await +} + +#[cfg(unix)] +mod test { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + #[tokio::test] + async fn asyncsession_client_server_test() { + use mbedtls::ssl::Version; + + #[derive(Copy,Clone)] + struct TestConfig { + min_c: Version, + max_c: Version, + min_s: Version, + max_s: Version, + exp_ver: Option, + } + + impl TestConfig { + pub fn new(min_c: Version, max_c: Version, min_s: Version, max_s: Version, exp_ver: Option) -> Self { + TestConfig { min_c, max_c, min_s, max_s, exp_ver } + } + } + + let test_configs = [ + TestConfig::new(Version::Ssl3, Version::Ssl3, Version::Ssl3, Version::Ssl3, Some(Version::Ssl3)), + TestConfig::new(Version::Ssl3, Version::Tls1_2, Version::Ssl3, Version::Ssl3, Some(Version::Ssl3)), + TestConfig::new(Version::Tls1_0, Version::Tls1_0, Version::Tls1_0, Version::Tls1_0, Some(Version::Tls1_0)), + TestConfig::new(Version::Tls1_1, Version::Tls1_1, Version::Tls1_1, Version::Tls1_1, Some(Version::Tls1_1)), + TestConfig::new(Version::Tls1_2, Version::Tls1_2, Version::Tls1_2, Version::Tls1_2, Some(Version::Tls1_2)), + TestConfig::new(Version::Tls1_0, Version::Tls1_2, Version::Tls1_0, Version::Tls1_2, Some(Version::Tls1_2)), + TestConfig::new(Version::Tls1_2, Version::Tls1_2, Version::Tls1_0, Version::Tls1_2, Some(Version::Tls1_2)), + TestConfig::new(Version::Tls1_0, Version::Tls1_1, Version::Tls1_2, Version::Tls1_2, None) + ]; + + for config in &test_configs { + let min_c = config.min_c; + let max_c = config.max_c; + let min_s = config.min_s; + let max_s = config.max_s; + let exp_ver = config.exp_ver; + + if (max_c < Version::Tls1_2 || max_s < Version::Tls1_2) && !cfg!(feature = "legacy_protocols") { + continue; + } + + let (c, s) = crate::support::net::create_tcp_pair_async().unwrap(); + let c = tokio::spawn(super::client(c, min_c, max_c, exp_ver.clone())); + let s = tokio::spawn(super::server(s, min_s, max_s, exp_ver)); + + c.await.unwrap().unwrap(); + s.await.unwrap().unwrap(); + } + } + + #[tokio::test] + async fn asyncsession_shutdown1() { + let (c, s) = crate::support::net::create_tcp_pair_async().unwrap(); + + let c = tokio::spawn(super::with_client(c, |mut session| Box::pin(async move { + session.shutdown().await.unwrap(); + }))); + + let s = tokio::spawn(super::with_server(s, |mut session| Box::pin(async move { + let mut buf = [0u8; 1]; + match session.read(&mut buf).await { + Ok(0) | Err(_) => {} + _ => panic!("expected no data"), + } + }))); + + c.await.unwrap(); + s.await.unwrap(); + } + + #[tokio::test] + async fn asyncsession_shutdown2() { + let (c, s) = crate::support::net::create_tcp_pair_async().unwrap(); + + let c = tokio::spawn(super::with_client(c, |mut session| Box::pin(async move { + let mut buf = [0u8; 5]; + session.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"hello"); + match session.read(&mut buf).await { + Ok(0) | Err(_) => {} + _ => panic!("expected no data"), + } + }))); + + let s = tokio::spawn(super::with_server(s, |mut session| Box::pin(async move { + session.write_all(b"hello").await.unwrap(); + session.shutdown().await.unwrap(); + }))); + + c.await.unwrap(); + s.await.unwrap(); + } + + #[tokio::test] + async fn asyncsession_shutdown3() { + let (c, s) = crate::support::net::create_tcp_pair_async().unwrap(); + + let c = tokio::spawn(super::with_client(c, |mut session| Box::pin(async move { + session.shutdown().await + }))); + + let s = tokio::spawn(super::with_server(s, |mut session| Box::pin(async move { + session.shutdown().await + }))); + + match (c.await.unwrap(), s.await.unwrap()) { + (Err(_), Err(_)) => panic!("at least one should succeed"), + _ => {} + } + } +} diff --git a/mbedtls/tests/ec.rs b/mbedtls/tests/ec.rs index 2ecdc25c4..aea532209 100644 --- a/mbedtls/tests/ec.rs +++ b/mbedtls/tests/ec.rs @@ -44,7 +44,7 @@ wvkbR/h/+CNU1mMPdGoooNsldBtbNKgoAIsirMI/kk+q+9TTP4HqZpVt/qor/fz1 #[test] fn sign_verify() { - let mut k = Pk::from_private_key(TEST_KEY_PEM.as_bytes(), None).unwrap(); + let k = Pk::from_private_key(TEST_KEY_PEM.as_bytes(), None).unwrap(); let data = b"SIGNATURE TEST SIGNATURE TEST SI"; let mut signature1 = [0u8; ECDSA_MAX_LEN]; @@ -67,7 +67,7 @@ fn sign_verify() { #[test] fn verify_failure() { - let mut k = Pk::from_private_key(TEST_KEY_PEM.as_bytes(), None).unwrap(); + let k = Pk::from_private_key(TEST_KEY_PEM.as_bytes(), None).unwrap(); let data = b"SIGNATURE TEST SIGNATURE TEST SI"; let mut signature = [0u8; ECDSA_MAX_LEN]; @@ -150,7 +150,7 @@ fn sign_verify_rfc6979_sig() { #[test] fn buffer_too_small() { - let mut k = Pk::from_private_key(TEST_KEY_PEM.as_bytes(), None).unwrap(); + let k = Pk::from_private_key(TEST_KEY_PEM.as_bytes(), None).unwrap(); let data = b"SIGNATURE TEST SIGNATURE TEST SI"; let mut signature = [0u8; ECDSA_MAX_LEN - 1]; diff --git a/mbedtls/tests/hyper.rs b/mbedtls/tests/hyper.rs index 07a890ba7..f753582b9 100644 --- a/mbedtls/tests/hyper.rs +++ b/mbedtls/tests/hyper.rs @@ -12,12 +12,12 @@ use mbedtls::ssl::{Config, Context}; // Native TLS compatibility - to move to native tls client in the future #[derive(Clone)] pub struct TlsStream { - context: Arc>, + context: Arc>>, phantom: PhantomData, } impl TlsStream { - pub fn new(context: Arc>) -> Self { + pub fn new(context: Arc>>) -> Self { TlsStream { context: context, phantom: PhantomData, @@ -28,14 +28,14 @@ impl TlsStream { unsafe impl Send for TlsStream {} unsafe impl Sync for TlsStream {} -impl io::Read for TlsStream +impl io::Read for TlsStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.context.lock().unwrap().read(buf) } } -impl io::Write for TlsStream +impl io::Write for TlsStream { fn write(&mut self, buf: &[u8]) -> io::Result { self.context.lock().unwrap().write(buf) @@ -52,19 +52,19 @@ impl NetworkStream for TlsStream fn peer_addr(&mut self) -> io::Result { self.context.lock().unwrap().io_mut() .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))? - .downcast_mut::().unwrap().peer_addr() + .peer_addr() } fn set_read_timeout(&self, dur: Option) -> io::Result<()> { self.context.lock().unwrap().io_mut() .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))? - .downcast_mut::().unwrap().set_read_timeout(dur) + .set_read_timeout(dur) } fn set_write_timeout(&self, dur: Option) -> io::Result<()> { self.context.lock().unwrap().io_mut() .ok_or(IoError::new(IoErrorKind::NotFound, "No peer available"))? - .downcast_mut::().unwrap().set_write_timeout(dur) + .set_write_timeout(dur) } } diff --git a/mbedtls/tests/hyper13.rs b/mbedtls/tests/hyper13.rs new file mode 100644 index 000000000..b1b2448f7 --- /dev/null +++ b/mbedtls/tests/hyper13.rs @@ -0,0 +1,322 @@ +#![allow(unused_imports)] + + +use async_stream::stream; + +use std::fmt; +use std::future::Future; +use std::io; +use std::io::{Error as IoError}; +use std::pin::Pin; +use std::sync::{Arc}; +use std::task::{Context as TaskContext, Poll}; +use std::net::SocketAddr; + +use hyper13::Server; +use hyper13::service::{make_service_fn, service_fn}; +use hyper13::client::connect::{Connected, Connection}; +use hyper13::{Client, service::Service, Uri, Request, Body, Method, Response, StatusCode}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::{TcpStream, TcpListener}; +use tokio_02::io::{AsyncRead as AsyncRead02, AsyncWrite as AsyncWrite02}; + +use mbedtls::ssl::async_utils::IoAdapter; +use mbedtls::ssl::{Config, AsyncContext}; + +use futures::stream::{FuturesUnordered}; + +#[derive(Clone)] +pub struct HttpsConnector { + config: Arc, +} + +#[derive(Debug)] +struct ForceHttpsButUriNotHttps; + +impl std::error::Error for ForceHttpsButUriNotHttps {} + +impl fmt::Display for ForceHttpsButUriNotHttps { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("https required but URI was not https") + } +} + +const DEFAULT_HTTPS_PORT: u16 = 443; + +impl Service for HttpsConnector { + type Response = IoCompat>; + type Error = Box; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut TaskContext<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, dst: Uri) -> Self::Future { + // Strip [] for IPv6 addresses + let host = dst.host().unwrap_or("").trim_matches(|c| c == '[' || c == ']').to_owned(); + let port = dst.port_u16().unwrap_or(DEFAULT_HTTPS_PORT); + let config = self.config.clone(); + + Box::pin(async move { + if dst.scheme_str() != Some("https") { + return Err(ForceHttpsButUriNotHttps.into()); + } + + let tcp = TcpStream::connect((host.clone(), port)).await?; + let mut tls = AsyncContext::new(config); + tls.establish_async(tcp, Some(&host)).await?; + Ok(IoCompat(tls)) + }) + } +} + +// IoCompat is needed because hyper 0.13 relies on tokio 0.2's `AsyncRead` +// and `AsyncWrite` traits. It would have been nice if we could use +// `tokio_compat_02::IoCompat`, but that type does not implement `Connection` +// and we cannot impl `Connection` for it here either since it's not defined +// in this crate. +pub struct IoCompat(T); + +impl Connection for IoCompat { + fn connected(&self) -> Connected { + let connected = Connected::new(); + //check_alpn(&self.0, connected) + connected + } +} + +impl AsyncRead02 for IoCompat { + fn poll_read(self: Pin<&mut Self>, cx: &mut TaskContext, buf: &mut [u8]) -> Poll> { + let mut read_buf = ReadBuf::new(buf); + match Pin::new(&mut self.get_mut().0).poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())), + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } + } +} + +impl AsyncWrite02 for IoCompat { + fn poll_write(self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_write(cx, buf) + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_flush(cx) + } + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_shutdown(cx) + } +} + +#[derive(Copy, Clone)] +pub struct TokioExecutor; + +impl hyper13::rt::Executor for TokioExecutor +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + fn execute(&self, fut: F) { + tokio::spawn(fut); + } +} + + + +use tokio_02::stream::Stream; + +type TlsFuture = Pin>, IoError>> + Send>>; + +pub struct HyperAcceptor { + clients: FuturesUnordered>>, io::Error>>>, + listener: TcpListener, + config: Arc, +} + +impl HyperAcceptor { + pub async fn create(config: Arc, addr: &str) -> Result { + let listener = TcpListener::bind(addr).await?; + + Ok(HyperAcceptor { + clients: FuturesUnordered::new(), + listener, + config, + }) + } +} + +const MAX_CONCURRENT_ACCEPTS: usize = 100; + +impl hyper13::server::accept::Accept for HyperAcceptor { + type Conn = IoCompat>; + type Error = io::Error; + + fn poll_accept(mut self: Pin<&mut Self>, cx: &mut TaskContext,) -> Poll>> { + if self.clients.len() < MAX_CONCURRENT_ACCEPTS { + match self.listener.poll_accept(cx) { + Poll::Pending => (), + Poll::Ready(Ok((conn, _addr))) => { + let config = self.config.clone(); + self.clients.push(tokio::spawn(async move { + let context = AsyncContext::accept_async(config, conn, None).await?; + Ok(IoCompat(context)) + })); + }, + Poll::Ready(Err(e)) => { + // We likely don't care about user errors enough to stop processing under normal circumstances + return Poll::Ready(Some(Err(e))); + }, + }; + } + + if self.clients.len() > 0 { + match Pin::new(&mut self.clients).poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(v)) => Poll::Ready(Some(v?)), // fold Result Poll::Ready(None), + } + } else { + Poll::Pending + } + } +} + + + +#[cfg(test)] +mod tests { + // Note this useful idiom: importing names from outer (for mod tests) scope. + use super::*; + + use mbedtls::pk::Pk; + use mbedtls::ssl::Config; + use mbedtls::ssl::config::{Endpoint, Preset, Transport, AuthMode, Version, UseSessionTickets, Renegotiation}; + use mbedtls::ssl::context::HandshakeContext; + use mbedtls::x509::{Certificate, VerifyError}; + use std::sync::Arc; + use mbedtls::ssl::CipherSuite::*; + use std::io::Write; + use mbedtls::ssl::TicketContext; + use std::time::Instant; + + #[cfg(not(target_env = "sgx"))] + use mbedtls::rng::{OsEntropy, CtrDrbg, HmacDrbg}; + + #[cfg(target_env = "sgx")] + use mbedtls::rng::{Rdrand}; + + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_02::stream::StreamExt; + use futures::stream::{FuturesUnordered}; + + #[cfg(not(target_env = "sgx"))] + pub fn rng_new() -> Arc { + let entropy = Arc::new(OsEntropy::new()); + let rng = Arc::new(CtrDrbg::new(entropy, None).unwrap()); + rng + } + + #[cfg(target_env = "sgx")] + pub fn rng_new() -> Arc { + Arc::new(Rdrand) + } + + pub const PEM_KEY: &'static [u8] = concat!(include_str!("./support/keys/user.key"),"\0").as_bytes(); + pub const PEM_CERT: &'static [u8] = concat!(include_str!("./support/keys/user.crt"),"\0").as_bytes(); + pub const ROOT_CA_CERT: &'static [u8] = concat!(include_str!("./support/keys/ca.crt"),"\0").as_bytes(); + + #[tokio::test] + async fn async_hyper_client_test() { + + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_authmode(AuthMode::None); + config.set_rng(rng_new()); + config.set_min_version(Version::Tls1_2).unwrap(); + + let https = HttpsConnector { config: Arc::new(config) }; + let client = Client::builder().executor(TokioExecutor).build::<_, hyper13::Body>(https); + + let res = client.get("https://hyper.rs".parse().unwrap()).await.unwrap(); + assert_eq!(res.status(), 200); + } + + async fn echo(req: Request) -> Result, hyper::Error> { + let mut response = Response::new(Body::empty()); + + match (req.method(), req.uri().path()) { + (&Method::GET, "/") => *response.body_mut() = Body::from("Try POST /echo\n"), + (&Method::POST, "/echo") => *response.body_mut() = req.into_body(), + _ => *response.status_mut() = StatusCode::NOT_FOUND, + }; + + Ok(response) + } + + async fn get_acceptor(address: &str) -> Result { + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + + config.set_rng(rng_new()); + config.set_authmode(AuthMode::None); + config.set_min_version(Version::Tls1_2).unwrap(); + + let cert = Arc::new(Certificate::from_pem_multiple(PEM_CERT).unwrap()); + let key = Arc::new(Pk::from_private_key(PEM_KEY, None).unwrap()); + config.push_cert(cert, key).unwrap(); + + HyperAcceptor::create(Arc::new(config), address).await + } + + #[tokio::test] + async fn async_hyper_server_fullhandshake_test() { + std::env::set_var("RUST_BACKTRACE", "full"); + + // Set up hyper server to echo function and a graceful shutdown + let acceptor = get_acceptor("127.0.0.1:0").await.unwrap(); + let local_addr = acceptor.listener.local_addr().unwrap().clone(); + + let service = make_service_fn(|_| async { Ok::<_, io::Error>(service_fn(echo)) }); + let server = Server::builder(acceptor).executor(TokioExecutor).serve(service); + + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + let graceful = server.with_graceful_shutdown(async { rx.await.ok(); }); + + let s = tokio::spawn(graceful); + + let mut clients = FuturesUnordered::new(); + + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_authmode(AuthMode::None); + config.set_rng(rng_new()); + config.set_min_version(Version::Tls1_2).unwrap(); + let config = Arc::new(config); + + let start = Instant::now(); + + for _ in 0..100 { + let config = config.clone(); + + clients.push(tokio::spawn(async move { + let client = Client::builder().executor(TokioExecutor).build::<_, hyper13::Body>(HttpsConnector { config }); + + let mut res = client.get(format!("https://{}/", local_addr).parse().unwrap()).await.unwrap(); + assert_eq!(res.status(), 200); + + let body_bytes = hyper13::body::to_bytes(res.into_body()).await.unwrap(); + let body = String::from_utf8(body_bytes.to_vec()).expect("response was not valid utf-8"); + assert_eq!(body, "Try POST /echo\n"); + })); + + if clients.len() > MAX_CONCURRENT_ACCEPTS { + clients.next().await.unwrap(); + } + } + + while let Some(r) = clients.next().await { + r.unwrap(); + } + + tx.send(()); + s.await.unwrap().unwrap(); + } +} diff --git a/mbedtls/tests/rsa.rs b/mbedtls/tests/rsa.rs index 33c3d68d2..84815191d 100644 --- a/mbedtls/tests/rsa.rs +++ b/mbedtls/tests/rsa.rs @@ -21,7 +21,7 @@ const EXPONENT: u32 = 0x10001; #[test] fn sign_verify() { - let mut k = Pk::generate_rsa(&mut test_rng(), RSA_BITS, EXPONENT).unwrap(); + let k = Pk::generate_rsa(&mut test_rng(), RSA_BITS, EXPONENT).unwrap(); let data = b"SIGNATURE TEST SIGNATURE TEST SI"; let mut signature = [0u8; RSA_BITS as usize / 8]; @@ -36,7 +36,7 @@ fn sign_verify() { #[test] fn buffer_too_small() { - let mut k = Pk::generate_rsa(&mut test_rng(), RSA_BITS, EXPONENT).unwrap(); + let k = Pk::generate_rsa(&mut test_rng(), RSA_BITS, EXPONENT).unwrap(); let data = b"SIGNATURE TEST SIGNATURE TEST SI"; let mut signature = [0u8; RSA_BITS as usize / 8 - 1]; diff --git a/mbedtls/tests/ssl_conf_ca_cb.rs b/mbedtls/tests/ssl_conf_ca_cb.rs index 880a699bc..370400333 100644 --- a/mbedtls/tests/ssl_conf_ca_cb.rs +++ b/mbedtls/tests/ssl_conf_ca_cb.rs @@ -25,8 +25,6 @@ use mbedtls::ssl::config::CaCallback; mod support; use support::entropy::entropy_new; -use mbedtls::alloc::{List as MbedtlsList}; - fn client(conn: TcpStream, ca_callback: F) -> TlsResult<()> where F: CaCallback + Send + 'static, @@ -62,7 +60,8 @@ mod test { use crate::support::keys; use mbedtls::x509::{Certificate}; use mbedtls::Error; - + use mbedtls::alloc::{List as MbedtlsList, Box as MbedtlsBox}; + // This callback should accept any valid self-signed certificate fn self_signed_ca_callback(child: &MbedtlsList) -> TlsResult> { Ok(child.clone()) diff --git a/mbedtls/tests/support/net.rs b/mbedtls/tests/support/net.rs index f061b32af..9446d7f7f 100644 --- a/mbedtls/tests/support/net.rs +++ b/mbedtls/tests/support/net.rs @@ -26,3 +26,14 @@ pub fn create_tcp_pair() -> IoResult<(TcpStream, TcpStream)> { } } } + +#[cfg(feature = "tokio")] +pub fn create_tcp_pair_async() -> IoResult<(tokio::net::TcpStream, tokio::net::TcpStream)> { + let (c, s) = create_tcp_pair()?; + c.set_nonblocking(true)?; + s.set_nonblocking(true)?; + Ok(( + tokio::net::TcpStream::from_std(c)?, + tokio::net::TcpStream::from_std(s)?, + )) +}