diff options
Diffstat (limited to 'executor/grpc_client.go')
-rw-r--r-- | executor/grpc_client.go | 267 |
1 files changed, 267 insertions, 0 deletions
diff --git a/executor/grpc_client.go b/executor/grpc_client.go new file mode 100644 index 0000000..7ab2dbf --- /dev/null +++ b/executor/grpc_client.go @@ -0,0 +1,267 @@ +package executor + +import ( + "context" + "fmt" + "io" + "os" + "syscall" + "time" + + "github.com/LK4D4/joincontext" + "github.com/golang/protobuf/ptypes" + hclog "github.com/hashicorp/go-hclog" + cstructs "github.com/hashicorp/nomad/client/structs" + "github.com/hashicorp/nomad/drivers/shared/executor/proto" + "github.com/hashicorp/nomad/helper/pluginutils/grpcutils" + "github.com/hashicorp/nomad/plugins/drivers" + dproto "github.com/hashicorp/nomad/plugins/drivers/proto" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var _ Executor = (*grpcExecutorClient)(nil) + +type grpcExecutorClient struct { + client proto.ExecutorClient + logger hclog.Logger + + // doneCtx is close when the plugin exits + doneCtx context.Context +} + +func (c *grpcExecutorClient) Launch(cmd *ExecCommand) (*ProcessState, error) { + ctx := context.Background() + req := &proto.LaunchRequest{ + Cmd: cmd.Cmd, + Args: cmd.Args, + Resources: drivers.ResourcesToProto(cmd.Resources), + StdoutPath: cmd.StdoutPath, + StderrPath: cmd.StderrPath, + Env: cmd.Env, + User: cmd.User, + TaskDir: cmd.TaskDir, + ResourceLimits: cmd.ResourceLimits, + BasicProcessCgroup: cmd.BasicProcessCgroup, + NoPivotRoot: cmd.NoPivotRoot, + Mounts: drivers.MountsToProto(cmd.Mounts), + Devices: drivers.DevicesToProto(cmd.Devices), + NetworkIsolation: drivers.NetworkIsolationSpecToProto(cmd.NetworkIsolation), + DefaultPidMode: cmd.ModePID, + DefaultIpcMode: cmd.ModeIPC, + Capabilities: cmd.Capabilities, + } + resp, err := c.client.Launch(ctx, req) + if err != nil { + return nil, err + } + + ps, err := processStateFromProto(resp.Process) + if err != nil { + return nil, err + } + return ps, nil +} + +func (c *grpcExecutorClient) Wait(ctx context.Context) (*ProcessState, error) { + // Join the passed context and the shutdown context + ctx, _ = joincontext.Join(ctx, c.doneCtx) + + resp, err := c.client.Wait(ctx, &proto.WaitRequest{}) + if err != nil { + return nil, err + } + + ps, err := processStateFromProto(resp.Process) + if err != nil { + return nil, err + } + + return ps, nil +} + +func (c *grpcExecutorClient) Shutdown(signal string, gracePeriod time.Duration) error { + ctx := context.Background() + req := &proto.ShutdownRequest{ + Signal: signal, + GracePeriod: gracePeriod.Nanoseconds(), + } + if _, err := c.client.Shutdown(ctx, req); err != nil { + return err + } + + return nil +} + +func (c *grpcExecutorClient) UpdateResources(r *drivers.Resources) error { + ctx := context.Background() + req := &proto.UpdateResourcesRequest{Resources: drivers.ResourcesToProto(r)} + if _, err := c.client.UpdateResources(ctx, req); err != nil { + return err + } + + return nil +} + +func (c *grpcExecutorClient) Version() (*ExecutorVersion, error) { + ctx := context.Background() + resp, err := c.client.Version(ctx, &proto.VersionRequest{}) + if err != nil { + return nil, err + } + return &ExecutorVersion{Version: resp.Version}, nil +} + +func (c *grpcExecutorClient) Stats(ctx context.Context, interval time.Duration) (<-chan *cstructs.TaskResourceUsage, error) { + stream, err := c.client.Stats(ctx, &proto.StatsRequest{ + Interval: int64(interval), + }) + if err != nil { + return nil, err + } + + ch := make(chan *cstructs.TaskResourceUsage) + go c.handleStats(ctx, stream, ch) + return ch, nil +} + +func (c *grpcExecutorClient) handleStats(ctx context.Context, stream proto.Executor_StatsClient, ch chan<- *cstructs.TaskResourceUsage) { + defer close(ch) + for { + resp, err := stream.Recv() + if ctx.Err() != nil { + // Context canceled; exit gracefully + return + } + + if err == io.EOF || + status.Code(err) == codes.Unavailable || + status.Code(err) == codes.Canceled || + err == context.Canceled { + c.logger.Trace("executor Stats stream closed", "msg", err) + return + } else if err != nil { + c.logger.Warn("failed to receive Stats executor RPC stream, closing stream", "error", err) + return + } + + stats, err := drivers.TaskStatsFromProto(resp.Stats) + if err != nil { + c.logger.Error("failed to decode stats from RPC", "error", err, "stats", resp.Stats) + continue + } + + select { + case ch <- stats: + case <-ctx.Done(): + return + } + } +} + +func (c *grpcExecutorClient) Signal(s os.Signal) error { + ctx := context.Background() + sig, ok := s.(syscall.Signal) + if !ok { + return fmt.Errorf("unsupported signal type: %q", s.String()) + } + req := &proto.SignalRequest{ + Signal: int32(sig), + } + if _, err := c.client.Signal(ctx, req); err != nil { + return err + } + + return nil +} + +func (c *grpcExecutorClient) Exec(deadline time.Time, cmd string, args []string) ([]byte, int, error) { + ctx := context.Background() + pbDeadline, err := ptypes.TimestampProto(deadline) + if err != nil { + return nil, 0, err + } + req := &proto.ExecRequest{ + Deadline: pbDeadline, + Cmd: cmd, + Args: args, + } + + resp, err := c.client.Exec(ctx, req) + if err != nil { + return nil, 0, err + } + + return resp.Output, int(resp.ExitCode), nil +} + +func (c *grpcExecutorClient) ExecStreaming(ctx context.Context, + command []string, + tty bool, + execStream drivers.ExecTaskStream) error { + + err := c.execStreaming(ctx, command, tty, execStream) + if err != nil { + return grpcutils.HandleGrpcErr(err, c.doneCtx) + } + return nil +} + +func (c *grpcExecutorClient) execStreaming(ctx context.Context, + command []string, + tty bool, + execStream drivers.ExecTaskStream) error { + + stream, err := c.client.ExecStreaming(ctx) + if err != nil { + return err + } + + err = stream.Send(&dproto.ExecTaskStreamingRequest{ + Setup: &dproto.ExecTaskStreamingRequest_Setup{ + Command: command, + Tty: tty, + }, + }) + if err != nil { + return err + } + + errCh := make(chan error, 1) + go func() { + for { + m, err := execStream.Recv() + if err == io.EOF { + return + } else if err != nil { + errCh <- err + return + } + + if err := stream.Send(m); err != nil { + errCh <- err + return + } + + } + }() + + for { + select { + case err := <-errCh: + return err + default: + } + + m, err := stream.Recv() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if err := execStream.Send(m); err != nil { + return err + } + } +} |