-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: Enhance SSH key handling and add integration tests
- Loading branch information
Showing
7 changed files
with
406 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
package ssh | ||
|
||
import ( | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/x509" | ||
"encoding/pem" | ||
"os" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/testcontainers/testcontainers-go" | ||
"github.com/testcontainers/testcontainers-go/wait" | ||
"golang.org/x/crypto/ssh" | ||
) | ||
|
||
func TestSSHClient(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
// Generate SSH key pair | ||
privateKey, publicKey, err := generateSSHKeyPair() | ||
assert.NoError(t, err) | ||
|
||
// Start an SSH server container | ||
req := testcontainers.ContainerRequest{ | ||
Image: "linuxserver/openssh-server", | ||
ExposedPorts: []string{"2222/tcp"}, | ||
Env: map[string]string{ | ||
"PUID": "1000", | ||
"PGID": "1000", | ||
"PUBLIC_KEY": publicKey, | ||
"SUDO_ACCESS": "true", | ||
"PASSWORD_ACCESS": "false", | ||
"USER_NAME": "abc", | ||
}, | ||
WaitingFor: wait.ForListeningPort("2222/tcp").WithStartupTimeout(2 * time.Minute), | ||
} | ||
|
||
container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ | ||
ContainerRequest: req, | ||
Started: true, | ||
}) | ||
assert.NoError(t, err) | ||
defer func() { | ||
err := container.Terminate(ctx) | ||
assert.NoError(t, err) | ||
}() | ||
|
||
// Get the host and port | ||
host, err := container.Host(ctx) | ||
assert.NoError(t, err) | ||
mappedPort, err := container.MappedPort(ctx, "2222") | ||
assert.NoError(t, err) | ||
port := mappedPort.Port() | ||
|
||
// Read the private key | ||
key := []byte(privateKey) | ||
|
||
// Wait for the SSH server to start | ||
time.Sleep(5 * time.Second) | ||
|
||
// Connect to the SSH server | ||
client, err := ConnectWithUser(host+":"+port, "abc", key) | ||
assert.NoError(t, err) | ||
defer client.Close() | ||
|
||
// Run a command | ||
err = client.RunCommand("echo 'Hello, World!'") | ||
assert.NoError(t, err) | ||
|
||
// Upload a file | ||
localFilePath := "test_upload.txt" | ||
remoteFilePath := "/tmp/test_upload.txt" | ||
|
||
err = os.WriteFile(localFilePath, []byte("test data"), 0644) | ||
assert.NoError(t, err) | ||
defer os.Remove(localFilePath) | ||
|
||
err = client.UploadFile(localFilePath, remoteFilePath) | ||
assert.NoError(t, err) | ||
|
||
// Verify the file was uploaded | ||
session, err := client.NewSession() | ||
assert.NoError(t, err) | ||
defer session.Close() | ||
|
||
output, err := session.CombinedOutput("cat " + remoteFilePath) | ||
assert.NoError(t, err) | ||
assert.Equal(t, "test data", string(output)) | ||
} | ||
|
||
func generateSSHKeyPair() (privateKey string, publicKey string, err error) { | ||
// Generate the key pair | ||
privateKeyObj, err := rsa.GenerateKey(rand.Reader, 2048) | ||
if err != nil { | ||
return "", "", err | ||
} | ||
|
||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKeyObj) | ||
privateKeyPem := pem.EncodeToMemory(&pem.Block{ | ||
Type: "RSA PRIVATE KEY", | ||
Bytes: privateKeyBytes, | ||
}) | ||
|
||
publicKeyObj, err := ssh.NewPublicKey(&privateKeyObj.PublicKey) | ||
if err != nil { | ||
return "", "", err | ||
} | ||
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKeyObj) | ||
|
||
return string(privateKeyPem), string(publicKeyBytes), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package utils | ||
|
||
import ( | ||
"os" | ||
"path/filepath" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestFindSSHKey(t *testing.T) { | ||
// Create temporary SSH keys in a temp directory | ||
tempDir := t.TempDir() | ||
sshKeyPath = tempDir | ||
|
||
keyContent := []byte("test-key") | ||
keyNames := []string{"id_rsa", "id_ecdsa", "id_ed25519"} | ||
|
||
// Write test keys | ||
for _, name := range keyNames { | ||
keyPath := filepath.Join(tempDir, name) | ||
err := os.WriteFile(keyPath, keyContent, 0600) | ||
assert.NoError(t, err) | ||
} | ||
|
||
// Override the home directory to point to tempDir | ||
originalHome := os.Getenv("HOME") | ||
defer os.Setenv("HOME", originalHome) | ||
|
||
_ = os.Setenv("HOME", tempDir) | ||
|
||
// Test with no keyPath, should find id_rsa first | ||
key, err := FindSSHKey("") | ||
assert.NoError(t, err) | ||
assert.Equal(t, keyContent, key) | ||
|
||
// Test with specified keyPath | ||
specifiedKeyPath := filepath.Join(tempDir, "custom_key") | ||
err = os.WriteFile(specifiedKeyPath, keyContent, 0600) | ||
assert.NoError(t, err) | ||
|
||
key, err = FindSSHKey(specifiedKeyPath) | ||
assert.NoError(t, err) | ||
assert.Equal(t, keyContent, key) | ||
|
||
// Test when no keys are found | ||
_ = os.RemoveAll(tempDir) | ||
key, err = FindSSHKey("") | ||
assert.Error(t, err) | ||
assert.Nil(t, key) | ||
} |