diff options
Diffstat (limited to 'executor/exec_utils.go')
-rw-r--r-- | executor/exec_utils.go | 285 |
1 files changed, 285 insertions, 0 deletions
diff --git a/executor/exec_utils.go b/executor/exec_utils.go new file mode 100644 index 0000000..1a048eb --- /dev/null +++ b/executor/exec_utils.go @@ -0,0 +1,285 @@ +package executor + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "sync" + "syscall" + + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/nomad/plugins/drivers" + dproto "github.com/hashicorp/nomad/plugins/drivers/proto" +) + +// execHelper is a convenient wrapper for starting and executing commands, and handling their output +type execHelper struct { + logger hclog.Logger + + // newTerminal function creates a tty appropriate for the command + // The returned pty end of tty function is to be called after process start. + newTerminal func() (pty func() (*os.File, error), tty *os.File, err error) + + // setTTY is a callback to configure the command with slave end of the tty of the terminal, when tty is enabled + setTTY func(tty *os.File) error + + // setTTY is a callback to configure the command with std{in|out|err}, when tty is disabled + setIO func(stdin io.Reader, stdout, stderr io.Writer) error + + // processStart starts the process, like `exec.Cmd.Start()` + processStart func() error + + // processWait blocks until command terminates and returns its final state + processWait func() (*os.ProcessState, error) +} + +func (e *execHelper) run(ctx context.Context, tty bool, stream drivers.ExecTaskStream) error { + if tty { + return e.runTTY(ctx, stream) + } + return e.runNoTTY(ctx, stream) +} + +func (e *execHelper) runTTY(ctx context.Context, stream drivers.ExecTaskStream) error { + ptyF, tty, err := e.newTerminal() + if err != nil { + return fmt.Errorf("failed to open a tty: %v", err) + } + defer tty.Close() + + if err := e.setTTY(tty); err != nil { + return fmt.Errorf("failed to set command tty: %v", err) + } + if err := e.processStart(); err != nil { + return fmt.Errorf("failed to start command: %v", err) + } + + var wg sync.WaitGroup + errCh := make(chan error, 3) + + pty, err := ptyF() + if err != nil { + return fmt.Errorf("failed to get pty: %v", err) + } + + defer pty.Close() + wg.Add(1) + go handleStdin(e.logger, pty, stream, errCh) + // when tty is on, stdout and stderr point to the same pty so only read once + go handleStdout(e.logger, pty, &wg, stream.Send, errCh) + + ps, err := e.processWait() + + // force close streams to close out the stream copying goroutines + tty.Close() + + // wait until we get all process output + wg.Wait() + + // wait to flush out output + stream.Send(cmdExitResult(ps, err)) + + select { + case cerr := <-errCh: + return cerr + default: + return nil + } +} + +func (e *execHelper) runNoTTY(ctx context.Context, stream drivers.ExecTaskStream) error { + var sendLock sync.Mutex + send := func(v *drivers.ExecTaskStreamingResponseMsg) error { + sendLock.Lock() + defer sendLock.Unlock() + + return stream.Send(v) + } + + stdinPr, stdinPw := io.Pipe() + stdoutPr, stdoutPw := io.Pipe() + stderrPr, stderrPw := io.Pipe() + + defer stdoutPw.Close() + defer stderrPw.Close() + + if err := e.setIO(stdinPr, stdoutPw, stderrPw); err != nil { + return fmt.Errorf("failed to set command io: %v", err) + } + + if err := e.processStart(); err != nil { + return fmt.Errorf("failed to start command: %v", err) + } + + var wg sync.WaitGroup + errCh := make(chan error, 3) + + wg.Add(2) + go handleStdin(e.logger, stdinPw, stream, errCh) + go handleStdout(e.logger, stdoutPr, &wg, send, errCh) + go handleStderr(e.logger, stderrPr, &wg, send, errCh) + + ps, err := e.processWait() + + // force close streams to close out the stream copying goroutines + stdinPr.Close() + stdoutPw.Close() + stderrPw.Close() + + // wait until we get all process output + wg.Wait() + + // wait to flush out output + stream.Send(cmdExitResult(ps, err)) + + select { + case cerr := <-errCh: + return cerr + default: + return nil + } +} +func cmdExitResult(ps *os.ProcessState, err error) *drivers.ExecTaskStreamingResponseMsg { + exitCode := -1 + + if ps == nil { + if ee, ok := err.(*exec.ExitError); ok { + ps = ee.ProcessState + } + } + + if ps == nil { + exitCode = -2 + } else if status, ok := ps.Sys().(syscall.WaitStatus); ok { + exitCode = status.ExitStatus() + if status.Signaled() { + const exitSignalBase = 128 + signal := int(status.Signal()) + exitCode = exitSignalBase + signal + } + } + + return &drivers.ExecTaskStreamingResponseMsg{ + Exited: true, + Result: &dproto.ExitResult{ + ExitCode: int32(exitCode), + }, + } +} + +func handleStdin(logger hclog.Logger, stdin io.WriteCloser, stream drivers.ExecTaskStream, errCh chan<- error) { + for { + m, err := stream.Recv() + if isClosedError(err) { + return + } else if err != nil { + errCh <- err + return + } + + if m.Stdin != nil { + if len(m.Stdin.Data) != 0 { + _, err := stdin.Write(m.Stdin.Data) + if err != nil { + errCh <- err + return + } + } + if m.Stdin.Close { + stdin.Close() + } + } else if m.TtySize != nil { + err := setTTYSize(stdin, m.TtySize.Height, m.TtySize.Width) + if err != nil { + errCh <- fmt.Errorf("failed to resize tty: %v", err) + return + } + } + } +} + +func handleStdout(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) { + defer wg.Done() + + buf := make([]byte, 4096) + for { + n, err := reader.Read(buf) + // always send output first if we read something + if n > 0 { + if err := send(&drivers.ExecTaskStreamingResponseMsg{ + Stdout: &dproto.ExecTaskStreamingIOOperation{ + Data: buf[:n], + }, + }); err != nil { + errCh <- err + return + } + } + + // then process error + if isClosedError(err) { + if err := send(&drivers.ExecTaskStreamingResponseMsg{ + Stdout: &dproto.ExecTaskStreamingIOOperation{ + Close: true, + }, + }); err != nil { + errCh <- err + return + } + return + } else if err != nil { + errCh <- err + return + } + + } +} + +func handleStderr(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) { + defer wg.Done() + + buf := make([]byte, 4096) + for { + n, err := reader.Read(buf) + // always send output first if we read something + if n > 0 { + if err := send(&drivers.ExecTaskStreamingResponseMsg{ + Stderr: &dproto.ExecTaskStreamingIOOperation{ + Data: buf[:n], + }, + }); err != nil { + errCh <- err + return + } + } + + // then process error + if isClosedError(err) { + if err := send(&drivers.ExecTaskStreamingResponseMsg{ + Stderr: &dproto.ExecTaskStreamingIOOperation{ + Close: true, + }, + }); err != nil { + errCh <- err + return + } + return + } else if err != nil { + errCh <- err + return + } + + } +} + +func isClosedError(err error) bool { + if err == nil { + return false + } + + return err == io.EOF || + err == io.ErrClosedPipe || + isUnixEIOErr(err) +} |