aboutsummaryrefslogtreecommitdiff
path: root/sftp/request-server.go
diff options
context:
space:
mode:
Diffstat (limited to 'sftp/request-server.go')
-rw-r--r--sftp/request-server.go304
1 files changed, 304 insertions, 0 deletions
diff --git a/sftp/request-server.go b/sftp/request-server.go
new file mode 100644
index 0000000..5fa828b
--- /dev/null
+++ b/sftp/request-server.go
@@ -0,0 +1,304 @@
+package sftp
+
+import (
+ "context"
+ "errors"
+ "io"
+ "path"
+ "path/filepath"
+ "strconv"
+ "sync"
+)
+
+var maxTxPacket uint32 = 1 << 15
+
+// Handlers contains the 4 SFTP server request handlers.
+type Handlers struct {
+ FileGet FileReader
+ FilePut FileWriter
+ FileCmd FileCmder
+ FileList FileLister
+}
+
+// RequestServer abstracts the sftp protocol with an http request-like protocol
+type RequestServer struct {
+ Handlers Handlers
+
+ *serverConn
+ pktMgr *packetManager
+
+ mu sync.RWMutex
+ handleCount int
+ openRequests map[string]*Request
+}
+
+// A RequestServerOption is a function which applies configuration to a RequestServer.
+type RequestServerOption func(*RequestServer)
+
+// WithRSAllocator enable the allocator.
+// After processing a packet we keep in memory the allocated slices
+// and we reuse them for new packets.
+// The allocator is experimental
+func WithRSAllocator() RequestServerOption {
+ return func(rs *RequestServer) {
+ alloc := newAllocator()
+ rs.pktMgr.alloc = alloc
+ rs.conn.alloc = alloc
+ }
+}
+
+// NewRequestServer creates/allocates/returns new RequestServer.
+// Normally there will be one server per user-session.
+func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
+ svrConn := &serverConn{
+ conn: conn{
+ Reader: rwc,
+ WriteCloser: rwc,
+ },
+ }
+ rs := &RequestServer{
+ Handlers: h,
+
+ serverConn: svrConn,
+ pktMgr: newPktMgr(svrConn),
+
+ openRequests: make(map[string]*Request),
+ }
+
+ for _, o := range options {
+ o(rs)
+ }
+ return rs
+}
+
+// New Open packet/Request
+func (rs *RequestServer) nextRequest(r *Request) string {
+ rs.mu.Lock()
+ defer rs.mu.Unlock()
+
+ rs.handleCount++
+
+ r.handle = strconv.Itoa(rs.handleCount)
+ rs.openRequests[r.handle] = r
+
+ return r.handle
+}
+
+// Returns Request from openRequests, bool is false if it is missing.
+//
+// The Requests in openRequests work essentially as open file descriptors that
+// you can do different things with. What you are doing with it are denoted by
+// the first packet of that type (read/write/etc).
+func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
+ rs.mu.RLock()
+ defer rs.mu.RUnlock()
+
+ r, ok := rs.openRequests[handle]
+ return r, ok
+}
+
+// Close the Request and clear from openRequests map
+func (rs *RequestServer) closeRequest(handle string) error {
+ rs.mu.Lock()
+ defer rs.mu.Unlock()
+
+ if r, ok := rs.openRequests[handle]; ok {
+ delete(rs.openRequests, handle)
+ return r.close()
+ }
+
+ return EBADF
+}
+
+// Close the read/write/closer to trigger exiting the main server loop
+func (rs *RequestServer) Close() error { return rs.conn.Close() }
+
+func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
+ defer close(pktChan) // shuts down sftpServerWorkers
+
+ var err error
+ var pkt requestPacket
+ var pktType uint8
+ var pktBytes []byte
+
+ for {
+ pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID())
+ if err != nil {
+ // we don't care about releasing allocated pages here, the server will quit and the allocator freed
+ return err
+ }
+
+ pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
+ if err != nil {
+ switch {
+ case errors.Is(err, errUnknownExtendedPacket):
+ // do nothing
+ default:
+ debug("makePacket err: %v", err)
+ rs.conn.Close() // shuts down recvPacket
+ return err
+ }
+ }
+
+ pktChan <- rs.pktMgr.newOrderedRequest(pkt)
+ }
+}
+
+// Serve requests for user session
+func (rs *RequestServer) Serve() error {
+ defer func() {
+ if rs.pktMgr.alloc != nil {
+ rs.pktMgr.alloc.Free()
+ }
+ }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ var wg sync.WaitGroup
+ runWorker := func(ch chan orderedRequest) {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := rs.packetWorker(ctx, ch); err != nil {
+ rs.conn.Close() // shuts down recvPacket
+ }
+ }()
+ }
+ pktChan := rs.pktMgr.workerChan(runWorker)
+
+ err := rs.serveLoop(pktChan)
+
+ wg.Wait() // wait for all workers to exit
+
+ rs.mu.Lock()
+ defer rs.mu.Unlock()
+
+ // make sure all open requests are properly closed
+ // (eg. possible on dropped connections, client crashes, etc.)
+ for handle, req := range rs.openRequests {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ req.transferError(err)
+
+ delete(rs.openRequests, handle)
+ req.close()
+ }
+
+ return err
+}
+
+func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedRequest) error {
+ for pkt := range pktChan {
+ orderID := pkt.orderID()
+ if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok {
+ if epkt.SpecificPacket != nil {
+ pkt.requestPacket = epkt.SpecificPacket
+ }
+ }
+
+ var rpkt responsePacket
+ switch pkt := pkt.requestPacket.(type) {
+ case *sshFxInitPacket:
+ rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions}
+ case *sshFxpClosePacket:
+ handle := pkt.getHandle()
+ rpkt = statusFromError(pkt.ID, rs.closeRequest(handle))
+ case *sshFxpRealpathPacket:
+ var realPath string
+ if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok {
+ realPath = realPather.RealPath(pkt.getPath())
+ } else {
+ realPath = cleanPath(pkt.getPath())
+ }
+ rpkt = cleanPacketPath(pkt, realPath)
+ case *sshFxpOpendirPacket:
+ request := requestFromPacket(ctx, pkt)
+ handle := rs.nextRequest(request)
+ rpkt = request.opendir(rs.Handlers, pkt)
+ if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
+ // if we return an error we have to remove the handle from the active ones
+ rs.closeRequest(handle)
+ }
+ case *sshFxpOpenPacket:
+ request := requestFromPacket(ctx, pkt)
+ handle := rs.nextRequest(request)
+ rpkt = request.open(rs.Handlers, pkt)
+ if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
+ // if we return an error we have to remove the handle from the active ones
+ rs.closeRequest(handle)
+ }
+ case *sshFxpFstatPacket:
+ handle := pkt.getHandle()
+ request, ok := rs.getRequest(handle)
+ if !ok {
+ rpkt = statusFromError(pkt.ID, EBADF)
+ } else {
+ request = NewRequest("Stat", request.Filepath)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ }
+ case *sshFxpFsetstatPacket:
+ handle := pkt.getHandle()
+ request, ok := rs.getRequest(handle)
+ if !ok {
+ rpkt = statusFromError(pkt.ID, EBADF)
+ } else {
+ request = NewRequest("Setstat", request.Filepath)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ }
+ case *sshFxpExtendedPacketPosixRename:
+ request := NewRequest("PosixRename", pkt.Oldpath)
+ request.Target = pkt.Newpath
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ case *sshFxpExtendedPacketStatVFS:
+ request := NewRequest("StatVFS", pkt.Path)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ case hasHandle:
+ handle := pkt.getHandle()
+ request, ok := rs.getRequest(handle)
+ if !ok {
+ rpkt = statusFromError(pkt.id(), EBADF)
+ } else {
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ }
+ case hasPath:
+ request := requestFromPacket(ctx, pkt)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ request.close()
+ default:
+ rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
+ }
+
+ rs.pktMgr.readyPacket(
+ rs.pktMgr.newOrderedResponse(rpkt, orderID))
+ }
+ return nil
+}
+
+// clean and return name packet for file
+func cleanPacketPath(pkt *sshFxpRealpathPacket, realPath string) responsePacket {
+ return &sshFxpNamePacket{
+ ID: pkt.id(),
+ NameAttrs: []*sshFxpNameAttr{
+ {
+ Name: realPath,
+ LongName: realPath,
+ Attrs: emptyFileStat,
+ },
+ },
+ }
+}
+
+// Makes sure we have a clean POSIX (/) absolute path to work with
+func cleanPath(p string) string {
+ return cleanPathWithBase("/", p)
+}
+
+func cleanPathWithBase(base, p string) string {
+ p = filepath.ToSlash(filepath.Clean(p))
+ if !path.IsAbs(p) {
+ return path.Join(base, p)
+ }
+ return p
+}