aboutsummaryrefslogtreecommitdiff
path: root/executor/exec_utils.go
diff options
context:
space:
mode:
Diffstat (limited to 'executor/exec_utils.go')
-rw-r--r--executor/exec_utils.go285
1 files changed, 285 insertions, 0 deletions
diff --git a/executor/exec_utils.go b/executor/exec_utils.go
new file mode 100644
index 0000000..1a048eb
--- /dev/null
+++ b/executor/exec_utils.go
@@ -0,0 +1,285 @@
+package executor
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "sync"
+ "syscall"
+
+ hclog "github.com/hashicorp/go-hclog"
+ "github.com/hashicorp/nomad/plugins/drivers"
+ dproto "github.com/hashicorp/nomad/plugins/drivers/proto"
+)
+
+// execHelper is a convenient wrapper for starting and executing commands, and handling their output
+type execHelper struct {
+ logger hclog.Logger
+
+ // newTerminal function creates a tty appropriate for the command
+ // The returned pty end of tty function is to be called after process start.
+ newTerminal func() (pty func() (*os.File, error), tty *os.File, err error)
+
+ // setTTY is a callback to configure the command with slave end of the tty of the terminal, when tty is enabled
+ setTTY func(tty *os.File) error
+
+ // setTTY is a callback to configure the command with std{in|out|err}, when tty is disabled
+ setIO func(stdin io.Reader, stdout, stderr io.Writer) error
+
+ // processStart starts the process, like `exec.Cmd.Start()`
+ processStart func() error
+
+ // processWait blocks until command terminates and returns its final state
+ processWait func() (*os.ProcessState, error)
+}
+
+func (e *execHelper) run(ctx context.Context, tty bool, stream drivers.ExecTaskStream) error {
+ if tty {
+ return e.runTTY(ctx, stream)
+ }
+ return e.runNoTTY(ctx, stream)
+}
+
+func (e *execHelper) runTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
+ ptyF, tty, err := e.newTerminal()
+ if err != nil {
+ return fmt.Errorf("failed to open a tty: %v", err)
+ }
+ defer tty.Close()
+
+ if err := e.setTTY(tty); err != nil {
+ return fmt.Errorf("failed to set command tty: %v", err)
+ }
+ if err := e.processStart(); err != nil {
+ return fmt.Errorf("failed to start command: %v", err)
+ }
+
+ var wg sync.WaitGroup
+ errCh := make(chan error, 3)
+
+ pty, err := ptyF()
+ if err != nil {
+ return fmt.Errorf("failed to get pty: %v", err)
+ }
+
+ defer pty.Close()
+ wg.Add(1)
+ go handleStdin(e.logger, pty, stream, errCh)
+ // when tty is on, stdout and stderr point to the same pty so only read once
+ go handleStdout(e.logger, pty, &wg, stream.Send, errCh)
+
+ ps, err := e.processWait()
+
+ // force close streams to close out the stream copying goroutines
+ tty.Close()
+
+ // wait until we get all process output
+ wg.Wait()
+
+ // wait to flush out output
+ stream.Send(cmdExitResult(ps, err))
+
+ select {
+ case cerr := <-errCh:
+ return cerr
+ default:
+ return nil
+ }
+}
+
+func (e *execHelper) runNoTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
+ var sendLock sync.Mutex
+ send := func(v *drivers.ExecTaskStreamingResponseMsg) error {
+ sendLock.Lock()
+ defer sendLock.Unlock()
+
+ return stream.Send(v)
+ }
+
+ stdinPr, stdinPw := io.Pipe()
+ stdoutPr, stdoutPw := io.Pipe()
+ stderrPr, stderrPw := io.Pipe()
+
+ defer stdoutPw.Close()
+ defer stderrPw.Close()
+
+ if err := e.setIO(stdinPr, stdoutPw, stderrPw); err != nil {
+ return fmt.Errorf("failed to set command io: %v", err)
+ }
+
+ if err := e.processStart(); err != nil {
+ return fmt.Errorf("failed to start command: %v", err)
+ }
+
+ var wg sync.WaitGroup
+ errCh := make(chan error, 3)
+
+ wg.Add(2)
+ go handleStdin(e.logger, stdinPw, stream, errCh)
+ go handleStdout(e.logger, stdoutPr, &wg, send, errCh)
+ go handleStderr(e.logger, stderrPr, &wg, send, errCh)
+
+ ps, err := e.processWait()
+
+ // force close streams to close out the stream copying goroutines
+ stdinPr.Close()
+ stdoutPw.Close()
+ stderrPw.Close()
+
+ // wait until we get all process output
+ wg.Wait()
+
+ // wait to flush out output
+ stream.Send(cmdExitResult(ps, err))
+
+ select {
+ case cerr := <-errCh:
+ return cerr
+ default:
+ return nil
+ }
+}
+func cmdExitResult(ps *os.ProcessState, err error) *drivers.ExecTaskStreamingResponseMsg {
+ exitCode := -1
+
+ if ps == nil {
+ if ee, ok := err.(*exec.ExitError); ok {
+ ps = ee.ProcessState
+ }
+ }
+
+ if ps == nil {
+ exitCode = -2
+ } else if status, ok := ps.Sys().(syscall.WaitStatus); ok {
+ exitCode = status.ExitStatus()
+ if status.Signaled() {
+ const exitSignalBase = 128
+ signal := int(status.Signal())
+ exitCode = exitSignalBase + signal
+ }
+ }
+
+ return &drivers.ExecTaskStreamingResponseMsg{
+ Exited: true,
+ Result: &dproto.ExitResult{
+ ExitCode: int32(exitCode),
+ },
+ }
+}
+
+func handleStdin(logger hclog.Logger, stdin io.WriteCloser, stream drivers.ExecTaskStream, errCh chan<- error) {
+ for {
+ m, err := stream.Recv()
+ if isClosedError(err) {
+ return
+ } else if err != nil {
+ errCh <- err
+ return
+ }
+
+ if m.Stdin != nil {
+ if len(m.Stdin.Data) != 0 {
+ _, err := stdin.Write(m.Stdin.Data)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ }
+ if m.Stdin.Close {
+ stdin.Close()
+ }
+ } else if m.TtySize != nil {
+ err := setTTYSize(stdin, m.TtySize.Height, m.TtySize.Width)
+ if err != nil {
+ errCh <- fmt.Errorf("failed to resize tty: %v", err)
+ return
+ }
+ }
+ }
+}
+
+func handleStdout(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
+ defer wg.Done()
+
+ buf := make([]byte, 4096)
+ for {
+ n, err := reader.Read(buf)
+ // always send output first if we read something
+ if n > 0 {
+ if err := send(&drivers.ExecTaskStreamingResponseMsg{
+ Stdout: &dproto.ExecTaskStreamingIOOperation{
+ Data: buf[:n],
+ },
+ }); err != nil {
+ errCh <- err
+ return
+ }
+ }
+
+ // then process error
+ if isClosedError(err) {
+ if err := send(&drivers.ExecTaskStreamingResponseMsg{
+ Stdout: &dproto.ExecTaskStreamingIOOperation{
+ Close: true,
+ },
+ }); err != nil {
+ errCh <- err
+ return
+ }
+ return
+ } else if err != nil {
+ errCh <- err
+ return
+ }
+
+ }
+}
+
+func handleStderr(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
+ defer wg.Done()
+
+ buf := make([]byte, 4096)
+ for {
+ n, err := reader.Read(buf)
+ // always send output first if we read something
+ if n > 0 {
+ if err := send(&drivers.ExecTaskStreamingResponseMsg{
+ Stderr: &dproto.ExecTaskStreamingIOOperation{
+ Data: buf[:n],
+ },
+ }); err != nil {
+ errCh <- err
+ return
+ }
+ }
+
+ // then process error
+ if isClosedError(err) {
+ if err := send(&drivers.ExecTaskStreamingResponseMsg{
+ Stderr: &dproto.ExecTaskStreamingIOOperation{
+ Close: true,
+ },
+ }); err != nil {
+ errCh <- err
+ return
+ }
+ return
+ } else if err != nil {
+ errCh <- err
+ return
+ }
+
+ }
+}
+
+func isClosedError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ return err == io.EOF ||
+ err == io.ErrClosedPipe ||
+ isUnixEIOErr(err)
+}