Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: toolbox exec session management #1732

Merged
merged 5 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 2 additions & 33 deletions pkg/agent/ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (
"io"
"os"
"os/exec"
"strings"
"syscall"
"unsafe"

"github.com/creack/pty"
"github.com/daytonaio/daytona/pkg/agent/ssh/config"
"github.com/daytonaio/daytona/pkg/common"
"github.com/gliderlabs/ssh"
"github.com/pkg/sftp"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -81,7 +81,7 @@ func (s *Server) Start() error {
}

func (s *Server) handlePty(session ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
shell := s.getShell()
shell := common.GetShell()
cmd := exec.Command(shell)

cmd.Dir = s.ProjectDir
Expand Down Expand Up @@ -233,37 +233,6 @@ func (s *Server) osSignalFrom(sig ssh.Signal) os.Signal {
}
}

func (s *Server) getShell() string {
out, err := exec.Command("sh", "-c", "grep '^[^#]' /etc/shells").Output()
if err != nil {
return "sh"
}

if strings.Contains(string(out), "/usr/bin/zsh") {
return "/usr/bin/zsh"
}

if strings.Contains(string(out), "/bin/zsh") {
return "/bin/zsh"
}

if strings.Contains(string(out), "/usr/bin/bash") {
return "/usr/bin/bash"
}

if strings.Contains(string(out), "/bin/bash") {
return "/bin/bash"
}

shellEnv, shellSet := os.LookupEnv("SHELL")

if shellSet {
return shellEnv
}

return "sh"
}

func (s *Server) sftpHandler(session ssh.Session) {
debugStream := io.Discard
serverOptions := []sftp.ServerOption{
Expand Down
87 changes: 45 additions & 42 deletions pkg/agent/toolbox/process/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,63 +14,66 @@ import (
"github.com/gin-gonic/gin"
)

func ExecuteCommand(c *gin.Context) {
var request ExecuteRequest
if err := c.ShouldBindJSON(&request); err != nil {
c.AbortWithError(400, errors.New("command is required"))
return
}
func ExecuteCommand(projectDir string) func(c *gin.Context) {
return func(c *gin.Context) {
var request ExecuteRequest
if err := c.ShouldBindJSON(&request); err != nil {
c.AbortWithError(400, errors.New("command is required"))
return
}

cmdParts := parseCommand(request.Command)
if len(cmdParts) == 0 {
c.AbortWithError(400, errors.New("empty command"))
return
}
cmdParts := parseCommand(request.Command)
if len(cmdParts) == 0 {
c.AbortWithError(400, errors.New("empty command"))
return
}

cmd := exec.Command(cmdParts[0], cmdParts[1:]...)
cmd := exec.Command(cmdParts[0], cmdParts[1:]...)
cmd.Dir = projectDir

// set maximum execution time
timeout := 10 * time.Second
if request.Timeout != nil && *request.Timeout > 0 {
timeout = time.Duration(*request.Timeout) * time.Second
}
// set maximum execution time
timeout := 10 * time.Second
if request.Timeout != nil && *request.Timeout > 0 {
timeout = time.Duration(*request.Timeout) * time.Second
}

timeoutReached := false
timer := time.AfterFunc(timeout, func() {
timeoutReached = true
if cmd.Process != nil {
// kill the process group
err := cmd.Process.Kill()
if err != nil {
log.Error(err)
return
}
}
})
defer timer.Stop()

timeoutReached := false
timer := time.AfterFunc(timeout, func() {
timeoutReached = true
if cmd.Process != nil {
// kill the process group
err := cmd.Process.Kill()
if err != nil {
log.Error(err)
output, err := cmd.CombinedOutput()
if err != nil {
if timeoutReached {
c.AbortWithError(408, errors.New("command execution timeout"))
return
}
c.AbortWithError(400, err)
return
}
})
defer timer.Stop()

output, err := cmd.CombinedOutput()
if err != nil {
if timeoutReached {
c.AbortWithError(408, errors.New("command execution timeout"))
if cmd.ProcessState == nil {
c.JSON(200, ExecuteResponse{
Code: -1,
Result: string(output),
})
return
}
c.AbortWithError(400, err)
return
}

if cmd.ProcessState == nil {
c.JSON(200, ExecuteResponse{
Code: -1,
Code: cmd.ProcessState.ExitCode(),
Result: string(output),
})
return
}

c.JSON(200, ExecuteResponse{
Code: cmd.ProcessState.ExitCode(),
Result: string(output),
})
}

// parseCommand splits a command string properly handling quotes
Expand Down
153 changes: 153 additions & 0 deletions pkg/agent/toolbox/process/session/execute.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Copyright 2024 Daytona Platforms Inc.
// SPDX-License-Identifier: Apache-2.0

package session

import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"

"github.com/daytonaio/daytona/internal/util"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)

func SessionExecuteCommand(configDir string) func(c *gin.Context) {
return func(c *gin.Context) {
sessionId := c.Param("sessionId")

var request SessionExecuteRequest
if err := c.ShouldBindJSON(&request); err != nil {
c.AbortWithError(http.StatusBadRequest, err)
return
}

session, ok := sessions[sessionId]
if !ok {
c.AbortWithError(http.StatusNotFound, errors.New("session not found"))
return
}

var cmdId *string
var logFile *os.File

cmdId = util.Pointer(uuid.NewString())

command := &Command{
Id: *cmdId,
Command: request.Command,
}
session.commands[*cmdId] = command

logFilePath := command.LogFilePath(session.Dir(configDir))

err := os.MkdirAll(filepath.Dir(logFilePath), 0755)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}

logFile, err = os.Create(logFilePath)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}

cmdToExec := fmt.Sprintf("%s > %s 2>&1 ; echo \"DTN_EXIT: $?\" >> %s\n", request.Command, logFile.Name(), logFile.Name())

type execResult struct {
out string
err error
exitCode *int
}
resultChan := make(chan execResult)

go func() {
out := ""
defer close(resultChan)

logChan := make(chan []byte)
errChan := make(chan error)

go util.ReadLog(context.Background(), logFile, true, logChan, errChan)

defer logFile.Close()

for {
select {
case logEntry := <-logChan:
logEntry = bytes.Trim(logEntry, "\x00")
if len(logEntry) == 0 {
continue
}
exitCode, line := extractExitCode(string(logEntry))
out += line

if exitCode != nil {
sessions[sessionId].commands[*cmdId].ExitCode = exitCode
resultChan <- execResult{out: out, exitCode: exitCode, err: nil}
return
}
case err := <-errChan:
if err != nil {
resultChan <- execResult{out: out, exitCode: nil, err: err}
return
}
}
}
}()

_, err = session.stdinWriter.Write([]byte(cmdToExec))
if err != nil {
c.AbortWithError(http.StatusBadRequest, err)
return
}

if request.Async {
c.JSON(http.StatusAccepted, SessionExecuteResponse{
CommandId: cmdId,
})
return
}

result := <-resultChan
if result.err != nil {
c.AbortWithError(http.StatusBadRequest, result.err)
return
}

c.JSON(http.StatusOK, SessionExecuteResponse{
CommandId: cmdId,
Output: &result.out,
ExitCode: result.exitCode,
})
}
}

func extractExitCode(output string) (*int, string) {
var exitCode *int

regex := regexp.MustCompile(`DTN_EXIT: (\d+)\n`)
matches := regex.FindStringSubmatch(output)
if len(matches) > 1 {
code, err := strconv.Atoi(matches[1])
if err != nil {
return nil, output
}
exitCode = &code
}

if exitCode != nil {
output = strings.Replace(output, fmt.Sprintf("DTN_EXIT: %d\n", *exitCode), "", 1)
}

return exitCode, output
}
74 changes: 74 additions & 0 deletions pkg/agent/toolbox/process/session/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2024 Daytona Platforms Inc.
// SPDX-License-Identifier: Apache-2.0

package session

import (
"errors"
"net/http"
"os"

"github.com/daytonaio/daytona/internal/util"
"github.com/daytonaio/daytona/pkg/api/controllers/log"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)

func GetSessionCommandLogs(configDir string) func(c *gin.Context) {
return func(c *gin.Context) {
sessionId := c.Param("sessionId")
cmdId := c.Param("commandId")

session, ok := sessions[sessionId]
if !ok {
c.AbortWithError(http.StatusNotFound, errors.New("session not found"))
return
}

command, ok := sessions[sessionId].commands[cmdId]
if !ok {
c.AbortWithError(http.StatusNotFound, errors.New("command not found"))
return
}

path := command.LogFilePath(session.Dir(configDir))

if c.Request.Header.Get("Upgrade") == "websocket" {
logFile, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
c.AbortWithError(http.StatusNotFound, errors.New("log file not found"))
return
}
c.AbortWithError(http.StatusInternalServerError, err)
return
}
defer logFile.Close()
log.ReadLog(c, logFile, util.ReadLog, func(conn *websocket.Conn, messages chan []byte, errors chan error) {
for {
msg := <-messages
_, output := extractExitCode(string(msg))
err := conn.WriteMessage(websocket.TextMessage, []byte(output))
if err != nil {
errors <- err
break
}
}
})
return
}

content, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
c.AbortWithError(http.StatusNotFound, errors.New("log file not found"))
return
}
c.AbortWithError(http.StatusInternalServerError, err)
return
}

_, output := extractExitCode(string(content))
c.String(http.StatusOK, output)
}
}
Loading
Loading