aboutsummaryrefslogblamecommitdiff
path: root/executor/exec_utils.go
blob: 1a048eb233841d048a811759d8321353b61ab642 (plain) (tree)




























































































































































































































































































                                                                                                                                                          
package executor

import (
	"context"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"syscall"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/nomad/plugins/drivers"
	dproto "github.com/hashicorp/nomad/plugins/drivers/proto"
)

// execHelper is a convenient wrapper for starting and executing commands, and handling their output
type execHelper struct {
	logger hclog.Logger

	// newTerminal function creates a tty appropriate for the command
	// The returned pty end of tty function is to be called after process start.
	newTerminal func() (pty func() (*os.File, error), tty *os.File, err error)

	// setTTY is a callback to configure the command with slave end of the tty of the terminal, when tty is enabled
	setTTY func(tty *os.File) error

	// setTTY is a callback to configure the command with std{in|out|err}, when tty is disabled
	setIO func(stdin io.Reader, stdout, stderr io.Writer) error

	// processStart starts the process, like `exec.Cmd.Start()`
	processStart func() error

	// processWait blocks until command terminates and returns its final state
	processWait func() (*os.ProcessState, error)
}

func (e *execHelper) run(ctx context.Context, tty bool, stream drivers.ExecTaskStream) error {
	if tty {
		return e.runTTY(ctx, stream)
	}
	return e.runNoTTY(ctx, stream)
}

func (e *execHelper) runTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
	ptyF, tty, err := e.newTerminal()
	if err != nil {
		return fmt.Errorf("failed to open a tty: %v", err)
	}
	defer tty.Close()

	if err := e.setTTY(tty); err != nil {
		return fmt.Errorf("failed to set command tty: %v", err)
	}
	if err := e.processStart(); err != nil {
		return fmt.Errorf("failed to start command: %v", err)
	}

	var wg sync.WaitGroup
	errCh := make(chan error, 3)

	pty, err := ptyF()
	if err != nil {
		return fmt.Errorf("failed to get pty: %v", err)
	}

	defer pty.Close()
	wg.Add(1)
	go handleStdin(e.logger, pty, stream, errCh)
	// when tty is on, stdout and stderr point to the same pty so only read once
	go handleStdout(e.logger, pty, &wg, stream.Send, errCh)

	ps, err := e.processWait()

	// force close streams to close out the stream copying goroutines
	tty.Close()

	// wait until we get all process output
	wg.Wait()

	// wait to flush out output
	stream.Send(cmdExitResult(ps, err))

	select {
	case cerr := <-errCh:
		return cerr
	default:
		return nil
	}
}

func (e *execHelper) runNoTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
	var sendLock sync.Mutex
	send := func(v *drivers.ExecTaskStreamingResponseMsg) error {
		sendLock.Lock()
		defer sendLock.Unlock()

		return stream.Send(v)
	}

	stdinPr, stdinPw := io.Pipe()
	stdoutPr, stdoutPw := io.Pipe()
	stderrPr, stderrPw := io.Pipe()

	defer stdoutPw.Close()
	defer stderrPw.Close()

	if err := e.setIO(stdinPr, stdoutPw, stderrPw); err != nil {
		return fmt.Errorf("failed to set command io: %v", err)
	}

	if err := e.processStart(); err != nil {
		return fmt.Errorf("failed to start command: %v", err)
	}

	var wg sync.WaitGroup
	errCh := make(chan error, 3)

	wg.Add(2)
	go handleStdin(e.logger, stdinPw, stream, errCh)
	go handleStdout(e.logger, stdoutPr, &wg, send, errCh)
	go handleStderr(e.logger, stderrPr, &wg, send, errCh)

	ps, err := e.processWait()

	// force close streams to close out the stream copying goroutines
	stdinPr.Close()
	stdoutPw.Close()
	stderrPw.Close()

	// wait until we get all process output
	wg.Wait()

	// wait to flush out output
	stream.Send(cmdExitResult(ps, err))

	select {
	case cerr := <-errCh:
		return cerr
	default:
		return nil
	}
}
func cmdExitResult(ps *os.ProcessState, err error) *drivers.ExecTaskStreamingResponseMsg {
	exitCode := -1

	if ps == nil {
		if ee, ok := err.(*exec.ExitError); ok {
			ps = ee.ProcessState
		}
	}

	if ps == nil {
		exitCode = -2
	} else if status, ok := ps.Sys().(syscall.WaitStatus); ok {
		exitCode = status.ExitStatus()
		if status.Signaled() {
			const exitSignalBase = 128
			signal := int(status.Signal())
			exitCode = exitSignalBase + signal
		}
	}

	return &drivers.ExecTaskStreamingResponseMsg{
		Exited: true,
		Result: &dproto.ExitResult{
			ExitCode: int32(exitCode),
		},
	}
}

func handleStdin(logger hclog.Logger, stdin io.WriteCloser, stream drivers.ExecTaskStream, errCh chan<- error) {
	for {
		m, err := stream.Recv()
		if isClosedError(err) {
			return
		} else if err != nil {
			errCh <- err
			return
		}

		if m.Stdin != nil {
			if len(m.Stdin.Data) != 0 {
				_, err := stdin.Write(m.Stdin.Data)
				if err != nil {
					errCh <- err
					return
				}
			}
			if m.Stdin.Close {
				stdin.Close()
			}
		} else if m.TtySize != nil {
			err := setTTYSize(stdin, m.TtySize.Height, m.TtySize.Width)
			if err != nil {
				errCh <- fmt.Errorf("failed to resize tty: %v", err)
				return
			}
		}
	}
}

func handleStdout(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
	defer wg.Done()

	buf := make([]byte, 4096)
	for {
		n, err := reader.Read(buf)
		// always send output first if we read something
		if n > 0 {
			if err := send(&drivers.ExecTaskStreamingResponseMsg{
				Stdout: &dproto.ExecTaskStreamingIOOperation{
					Data: buf[:n],
				},
			}); err != nil {
				errCh <- err
				return
			}
		}

		// then process error
		if isClosedError(err) {
			if err := send(&drivers.ExecTaskStreamingResponseMsg{
				Stdout: &dproto.ExecTaskStreamingIOOperation{
					Close: true,
				},
			}); err != nil {
				errCh <- err
				return
			}
			return
		} else if err != nil {
			errCh <- err
			return
		}

	}
}

func handleStderr(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
	defer wg.Done()

	buf := make([]byte, 4096)
	for {
		n, err := reader.Read(buf)
		// always send output first if we read something
		if n > 0 {
			if err := send(&drivers.ExecTaskStreamingResponseMsg{
				Stderr: &dproto.ExecTaskStreamingIOOperation{
					Data: buf[:n],
				},
			}); err != nil {
				errCh <- err
				return
			}
		}

		// then process error
		if isClosedError(err) {
			if err := send(&drivers.ExecTaskStreamingResponseMsg{
				Stderr: &dproto.ExecTaskStreamingIOOperation{
					Close: true,
				},
			}); err != nil {
				errCh <- err
				return
			}
			return
		} else if err != nil {
			errCh <- err
			return
		}

	}
}

func isClosedError(err error) bool {
	if err == nil {
		return false
	}

	return err == io.EOF ||
		err == io.ErrClosedPipe ||
		isUnixEIOErr(err)
}