From 811e3df882a52ce7b3b4289c62e826d0144ea0d1 Mon Sep 17 00:00:00 2001 From: WashingtonKK Date: Wed, 4 Dec 2024 16:24:41 +0300 Subject: [PATCH] add tests Signed-off-by: WashingtonKK --- internal/server/grpc/grpc_test.go | 134 ++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/internal/server/grpc/grpc_test.go b/internal/server/grpc/grpc_test.go index d4f2ee26..a7c59715 100644 --- a/internal/server/grpc/grpc_test.go +++ b/internal/server/grpc/grpc_test.go @@ -12,6 +12,7 @@ import ( "fmt" "log/slog" "math/big" + "os" "strings" "sync" "testing" @@ -135,6 +136,53 @@ func TestServerStartWithTLS(t *testing.T) { assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") } +func TestServerStartWithMTLS(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles() + assert.NoError(t, err) + + config := server.AgentConfig{ + ServerConfig: server.ServerConfig{ + BaseConfig: server.BaseConfig{ + Host: "localhost", + Port: "0", + CertFile: string(clientCertFile), + KeyFile: string(clientKeyFile), + ServerCAFile: caCertFile, + }, + }, + } + + logBuffer := &ThreadSafeBuffer{} + logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) + qp := new(mocks.QuoteProvider) + authSvc := new(authmocks.Authenticator) + + srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + wg.Done() + err := srv.Start() + assert.NoError(t, err) + }() + + wg.Wait() + + time.Sleep(200 * time.Millisecond) + + cancel() + + time.Sleep(200 * time.Millisecond) + + logContent := logBuffer.String() + fmt.Println(logContent) + assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") +} + func TestFailedServerStartWithTLS(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -303,3 +351,89 @@ func (b *ThreadSafeBuffer) String() string { defer b.mu.Unlock() return b.buffer.String() } + +func createCertificatesFiles() (string, string, string, error) { + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", "", err + } + + caTemplate := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return "", "", "", err + } + + caCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})) + if err != nil { + return "", "", "", err + } + + clientKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", "", err + } + + clientTemplate := x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey) + if err != nil { + return "", "", "", err + } + + clientCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})) + if err != nil { + return "", "", "", err + } + + clientKeyFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)})) + if err != nil { + return "", "", "", err + } + + return caCertFile, clientCertFile, clientKeyFile, nil +} + +func createTempFile(data []byte) (string, error) { + file, err := createTempFileHandle() + if err != nil { + return "", err + } + + _, err = file.Write(data) + if err != nil { + return "", err + } + + err = file.Close() + if err != nil { + return "", err + } + + return file.Name(), nil +} + +func createTempFileHandle() (*os.File, error) { + return os.CreateTemp("", "test") +}