aboutsummaryrefslogtreecommitdiff
path: root/sftp/request.go
diff options
context:
space:
mode:
Diffstat (limited to 'sftp/request.go')
-rw-r--r--sftp/request.go628
1 files changed, 628 insertions, 0 deletions
diff --git a/sftp/request.go b/sftp/request.go
new file mode 100644
index 0000000..c6da4b6
--- /dev/null
+++ b/sftp/request.go
@@ -0,0 +1,628 @@
+package sftp
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "strings"
+ "sync"
+ "syscall"
+)
+
+// MaxFilelist is the max number of files to return in a readdir batch.
+var MaxFilelist int64 = 100
+
+// state encapsulates the reader/writer/readdir from handlers.
+type state struct {
+ mu sync.RWMutex
+
+ writerAt io.WriterAt
+ readerAt io.ReaderAt
+ writerAtReaderAt WriterAtReaderAt
+ listerAt ListerAt
+ lsoffset int64
+}
+
+// copy returns a shallow copy the state.
+// This is broken out to specific fields,
+// because we have to copy around the mutex in state.
+func (s *state) copy() state {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return state{
+ writerAt: s.writerAt,
+ readerAt: s.readerAt,
+ writerAtReaderAt: s.writerAtReaderAt,
+ listerAt: s.listerAt,
+ lsoffset: s.lsoffset,
+ }
+}
+
+func (s *state) setReaderAt(rd io.ReaderAt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.readerAt = rd
+}
+
+func (s *state) getReaderAt() io.ReaderAt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.readerAt
+}
+
+func (s *state) setWriterAt(rd io.WriterAt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.writerAt = rd
+}
+
+func (s *state) getWriterAt() io.WriterAt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.writerAt
+}
+
+func (s *state) setWriterAtReaderAt(rw WriterAtReaderAt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.writerAtReaderAt = rw
+}
+
+func (s *state) getWriterAtReaderAt() WriterAtReaderAt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.writerAtReaderAt
+}
+
+func (s *state) getAllReaderWriters() (io.ReaderAt, io.WriterAt, WriterAtReaderAt) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.readerAt, s.writerAt, s.writerAtReaderAt
+}
+
+// Returns current offset for file list
+func (s *state) lsNext() int64 {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.lsoffset
+}
+
+// Increases next offset
+func (s *state) lsInc(offset int64) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.lsoffset += offset
+}
+
+// manage file read/write state
+func (s *state) setListerAt(la ListerAt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.listerAt = la
+}
+
+func (s *state) getListerAt() ListerAt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.listerAt
+}
+
+// Request contains the data and state for the incoming service request.
+type Request struct {
+ // Get, Put, Setstat, Stat, Rename, Remove
+ // Rmdir, Mkdir, List, Readlink, Link, Symlink
+ Method string
+ Filepath string
+ Flags uint32
+ Attrs []byte // convert to sub-struct
+ Target string // for renames and sym-links
+ handle string
+
+ // reader/writer/readdir from handlers
+ state
+
+ // context lasts duration of request
+ ctx context.Context
+ cancelCtx context.CancelFunc
+}
+
+// NewRequest creates a new Request object.
+func NewRequest(method, path string) *Request {
+ return &Request{
+ Method: method,
+ Filepath: cleanPath(path),
+ }
+}
+
+// copy returns a shallow copy of existing request.
+// This is broken out to specific fields,
+// because we have to copy around the mutex in state.
+func (r *Request) copy() *Request {
+ return &Request{
+ Method: r.Method,
+ Filepath: r.Filepath,
+ Flags: r.Flags,
+ Attrs: r.Attrs,
+ Target: r.Target,
+ handle: r.handle,
+
+ state: r.state.copy(),
+
+ ctx: r.ctx,
+ cancelCtx: r.cancelCtx,
+ }
+}
+
+// New Request initialized based on packet data
+func requestFromPacket(ctx context.Context, pkt hasPath) *Request {
+ method := requestMethod(pkt)
+ request := NewRequest(method, pkt.getPath())
+ request.ctx, request.cancelCtx = context.WithCancel(ctx)
+
+ switch p := pkt.(type) {
+ case *sshFxpOpenPacket:
+ request.Flags = p.Pflags
+ case *sshFxpSetstatPacket:
+ request.Flags = p.Flags
+ request.Attrs = p.Attrs.([]byte)
+ case *sshFxpRenamePacket:
+ request.Target = cleanPath(p.Newpath)
+ case *sshFxpSymlinkPacket:
+ // NOTE: given a POSIX compliant signature: symlink(target, linkpath string)
+ // this makes Request.Target the linkpath, and Request.Filepath the target.
+ request.Target = cleanPath(p.Linkpath)
+ case *sshFxpExtendedPacketHardlink:
+ request.Target = cleanPath(p.Newpath)
+ }
+ return request
+}
+
+// Context returns the request's context. To change the context,
+// use WithContext.
+//
+// The returned context is always non-nil; it defaults to the
+// background context.
+//
+// For incoming server requests, the context is canceled when the
+// request is complete or the client's connection closes.
+func (r *Request) Context() context.Context {
+ if r.ctx != nil {
+ return r.ctx
+ }
+ return context.Background()
+}
+
+// WithContext returns a copy of r with its context changed to ctx.
+// The provided ctx must be non-nil.
+func (r *Request) WithContext(ctx context.Context) *Request {
+ if ctx == nil {
+ panic("nil context")
+ }
+ r2 := r.copy()
+ r2.ctx = ctx
+ r2.cancelCtx = nil
+ return r2
+}
+
+// Close reader/writer if possible
+func (r *Request) close() error {
+ defer func() {
+ if r.cancelCtx != nil {
+ r.cancelCtx()
+ }
+ }()
+
+ rd, wr, rw := r.getAllReaderWriters()
+
+ var err error
+
+ // Close errors on a Writer are far more likely to be the important one.
+ // As they can be information that there was a loss of data.
+ if c, ok := wr.(io.Closer); ok {
+ if err2 := c.Close(); err == nil {
+ // update error if it is still nil
+ err = err2
+ }
+ }
+
+ if c, ok := rw.(io.Closer); ok {
+ if err2 := c.Close(); err == nil {
+ // update error if it is still nil
+ err = err2
+
+ r.setWriterAtReaderAt(nil)
+ }
+ }
+
+ if c, ok := rd.(io.Closer); ok {
+ if err2 := c.Close(); err == nil {
+ // update error if it is still nil
+ err = err2
+ }
+ }
+
+ return err
+}
+
+// Notify transfer error if any
+func (r *Request) transferError(err error) {
+ if err == nil {
+ return
+ }
+
+ rd, wr, rw := r.getAllReaderWriters()
+
+ if t, ok := wr.(TransferError); ok {
+ t.TransferError(err)
+ }
+
+ if t, ok := rw.(TransferError); ok {
+ t.TransferError(err)
+ }
+
+ if t, ok := rd.(TransferError); ok {
+ t.TransferError(err)
+ }
+}
+
+// called from worker to handle packet/request
+func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+ switch r.Method {
+ case "Get":
+ return fileget(handlers.FileGet, r, pkt, alloc, orderID)
+ case "Put":
+ return fileput(handlers.FilePut, r, pkt, alloc, orderID)
+ case "Open":
+ return fileputget(handlers.FilePut, r, pkt, alloc, orderID)
+ case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
+ return filecmd(handlers.FileCmd, r, pkt)
+ case "List":
+ return filelist(handlers.FileList, r, pkt)
+ case "Stat", "Lstat", "Readlink":
+ return filestat(handlers.FileList, r, pkt)
+ default:
+ return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method))
+ }
+}
+
+// Additional initialization for Open packets
+func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
+ flags := r.Pflags()
+
+ id := pkt.id()
+
+ switch {
+ case flags.Write, flags.Append, flags.Creat, flags.Trunc:
+ if flags.Read {
+ if openFileWriter, ok := h.FilePut.(OpenFileWriter); ok {
+ r.Method = "Open"
+ rw, err := openFileWriter.OpenFile(r)
+ if err != nil {
+ return statusFromError(id, err)
+ }
+
+ r.setWriterAtReaderAt(rw)
+
+ return &sshFxpHandlePacket{
+ ID: id,
+ Handle: r.handle,
+ }
+ }
+ }
+
+ r.Method = "Put"
+ wr, err := h.FilePut.Filewrite(r)
+ if err != nil {
+ return statusFromError(id, err)
+ }
+
+ r.setWriterAt(wr)
+
+ case flags.Read:
+ r.Method = "Get"
+ rd, err := h.FileGet.Fileread(r)
+ if err != nil {
+ return statusFromError(id, err)
+ }
+
+ r.setReaderAt(rd)
+
+ default:
+ return statusFromError(id, errors.New("bad file flags"))
+ }
+
+ return &sshFxpHandlePacket{
+ ID: id,
+ Handle: r.handle,
+ }
+}
+
+func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
+ r.Method = "List"
+ la, err := h.FileList.Filelist(r)
+ if err != nil {
+ return statusFromError(pkt.id(), wrapPathError(r.Filepath, err))
+ }
+
+ r.setListerAt(la)
+
+ return &sshFxpHandlePacket{
+ ID: pkt.id(),
+ Handle: r.handle,
+ }
+}
+
+// wrap FileReader handler
+func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+ rd := r.getReaderAt()
+ if rd == nil {
+ return statusFromError(pkt.id(), errors.New("unexpected read packet"))
+ }
+
+ data, offset, _ := packetData(pkt, alloc, orderID)
+
+ n, err := rd.ReadAt(data, offset)
+ // only return EOF error if no data left to read
+ if err != nil && (err != io.EOF || n == 0) {
+ return statusFromError(pkt.id(), err)
+ }
+
+ return &sshFxpDataPacket{
+ ID: pkt.id(),
+ Length: uint32(n),
+ Data: data[:n],
+ }
+}
+
+// wrap FileWriter handler
+func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+ wr := r.getWriterAt()
+ if wr == nil {
+ return statusFromError(pkt.id(), errors.New("unexpected write packet"))
+ }
+
+ data, offset, _ := packetData(pkt, alloc, orderID)
+
+ _, err := wr.WriteAt(data, offset)
+ return statusFromError(pkt.id(), err)
+}
+
+// wrap OpenFileWriter handler
+func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+ rw := r.getWriterAtReaderAt()
+ if rw == nil {
+ return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
+ }
+
+ switch p := pkt.(type) {
+ case *sshFxpReadPacket:
+ data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
+
+ n, err := rw.ReadAt(data, offset)
+ // only return EOF error if no data left to read
+ if err != nil && (err != io.EOF || n == 0) {
+ return statusFromError(pkt.id(), err)
+ }
+
+ return &sshFxpDataPacket{
+ ID: pkt.id(),
+ Length: uint32(n),
+ Data: data[:n],
+ }
+
+ case *sshFxpWritePacket:
+ data, offset := p.Data, int64(p.Offset)
+
+ _, err := rw.WriteAt(data, offset)
+ return statusFromError(pkt.id(), err)
+
+ default:
+ return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write"))
+ }
+}
+
+// file data for additional read/write packets
+func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
+ switch p := p.(type) {
+ case *sshFxpReadPacket:
+ return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len
+ case *sshFxpWritePacket:
+ return p.Data, int64(p.Offset), p.Length
+ }
+ return
+}
+
+// wrap FileCmder handler
+func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
+ switch p := pkt.(type) {
+ case *sshFxpFsetstatPacket:
+ r.Flags = p.Flags
+ r.Attrs = p.Attrs.([]byte)
+ }
+
+ switch r.Method {
+ case "PosixRename":
+ if posixRenamer, ok := h.(PosixRenameFileCmder); ok {
+ err := posixRenamer.PosixRename(r)
+ return statusFromError(pkt.id(), err)
+ }
+
+ // PosixRenameFileCmder not implemented handle this request as a Rename
+ r.Method = "Rename"
+ err := h.Filecmd(r)
+ return statusFromError(pkt.id(), err)
+
+ case "StatVFS":
+ if statVFSCmdr, ok := h.(StatVFSFileCmder); ok {
+ stat, err := statVFSCmdr.StatVFS(r)
+ if err != nil {
+ return statusFromError(pkt.id(), err)
+ }
+ stat.ID = pkt.id()
+ return stat
+ }
+
+ return statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
+ }
+
+ err := h.Filecmd(r)
+ return statusFromError(pkt.id(), err)
+}
+
+// wrap FileLister handler
+func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket {
+ lister := r.getListerAt()
+ if lister == nil {
+ return statusFromError(pkt.id(), errors.New("unexpected dir packet"))
+ }
+
+ offset := r.lsNext()
+ finfo := make([]os.FileInfo, MaxFilelist)
+ n, err := lister.ListAt(finfo, offset)
+ r.lsInc(int64(n))
+ // ignore EOF as we only return it when there are no results
+ finfo = finfo[:n] // avoid need for nil tests below
+
+ switch r.Method {
+ case "List":
+ if err != nil && (err != io.EOF || n == 0) {
+ return statusFromError(pkt.id(), err)
+ }
+
+ nameAttrs := make([]*sshFxpNameAttr, 0, len(finfo))
+
+ // If the type conversion fails, we get untyped `nil`,
+ // which is handled by not looking up any names.
+ idLookup, _ := h.(NameLookupFileLister)
+
+ for _, fi := range finfo {
+ nameAttrs = append(nameAttrs, &sshFxpNameAttr{
+ Name: fi.Name(),
+ LongName: runLs(idLookup, fi),
+ Attrs: []interface{}{fi},
+ })
+ }
+
+ return &sshFxpNamePacket{
+ ID: pkt.id(),
+ NameAttrs: nameAttrs,
+ }
+
+ default:
+ err = fmt.Errorf("unexpected method: %s", r.Method)
+ return statusFromError(pkt.id(), err)
+ }
+}
+
+func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {
+ var lister ListerAt
+ var err error
+
+ if r.Method == "Lstat" {
+ if lstatFileLister, ok := h.(LstatFileLister); ok {
+ lister, err = lstatFileLister.Lstat(r)
+ } else {
+ // LstatFileLister not implemented handle this request as a Stat
+ r.Method = "Stat"
+ lister, err = h.Filelist(r)
+ }
+ } else {
+ lister, err = h.Filelist(r)
+ }
+ if err != nil {
+ return statusFromError(pkt.id(), err)
+ }
+ finfo := make([]os.FileInfo, 1)
+ n, err := lister.ListAt(finfo, 0)
+ finfo = finfo[:n] // avoid need for nil tests below
+
+ switch r.Method {
+ case "Stat", "Lstat":
+ if err != nil && err != io.EOF {
+ return statusFromError(pkt.id(), err)
+ }
+ if n == 0 {
+ err = &os.PathError{
+ Op: strings.ToLower(r.Method),
+ Path: r.Filepath,
+ Err: syscall.ENOENT,
+ }
+ return statusFromError(pkt.id(), err)
+ }
+ return &sshFxpStatResponse{
+ ID: pkt.id(),
+ info: finfo[0],
+ }
+ case "Readlink":
+ if err != nil && err != io.EOF {
+ return statusFromError(pkt.id(), err)
+ }
+ if n == 0 {
+ err = &os.PathError{
+ Op: "readlink",
+ Path: r.Filepath,
+ Err: syscall.ENOENT,
+ }
+ return statusFromError(pkt.id(), err)
+ }
+ filename := finfo[0].Name()
+ return &sshFxpNamePacket{
+ ID: pkt.id(),
+ NameAttrs: []*sshFxpNameAttr{
+ {
+ Name: filename,
+ LongName: filename,
+ Attrs: emptyFileStat,
+ },
+ },
+ }
+ default:
+ err = fmt.Errorf("unexpected method: %s", r.Method)
+ return statusFromError(pkt.id(), err)
+ }
+}
+
+// init attributes of request object from packet data
+func requestMethod(p requestPacket) (method string) {
+ switch p.(type) {
+ case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket:
+ // set in open() above
+ case *sshFxpOpendirPacket, *sshFxpReaddirPacket:
+ // set in opendir() above
+ case *sshFxpSetstatPacket, *sshFxpFsetstatPacket:
+ method = "Setstat"
+ case *sshFxpRenamePacket:
+ method = "Rename"
+ case *sshFxpSymlinkPacket:
+ method = "Symlink"
+ case *sshFxpRemovePacket:
+ method = "Remove"
+ case *sshFxpStatPacket, *sshFxpFstatPacket:
+ method = "Stat"
+ case *sshFxpLstatPacket:
+ method = "Lstat"
+ case *sshFxpRmdirPacket:
+ method = "Rmdir"
+ case *sshFxpReadlinkPacket:
+ method = "Readlink"
+ case *sshFxpMkdirPacket:
+ method = "Mkdir"
+ case *sshFxpExtendedPacketHardlink:
+ method = "Link"
+ }
+ return method
+}