aboutsummaryrefslogblamecommitdiff
path: root/sftp/request-server.go
blob: 5fa828bb3b6443119ca02a365d6f85653c9fe968 (plain) (tree)















































































































































































































































































































                                                                                                                           
package sftp

import (
	"context"
	"errors"
	"io"
	"path"
	"path/filepath"
	"strconv"
	"sync"
)

var maxTxPacket uint32 = 1 << 15

// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
	FileGet  FileReader
	FilePut  FileWriter
	FileCmd  FileCmder
	FileList FileLister
}

// RequestServer abstracts the sftp protocol with an http request-like protocol
type RequestServer struct {
	Handlers Handlers

	*serverConn
	pktMgr *packetManager

	mu           sync.RWMutex
	handleCount  int
	openRequests map[string]*Request
}

// A RequestServerOption is a function which applies configuration to a RequestServer.
type RequestServerOption func(*RequestServer)

// WithRSAllocator enable the allocator.
// After processing a packet we keep in memory the allocated slices
// and we reuse them for new packets.
// The allocator is experimental
func WithRSAllocator() RequestServerOption {
	return func(rs *RequestServer) {
		alloc := newAllocator()
		rs.pktMgr.alloc = alloc
		rs.conn.alloc = alloc
	}
}

// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
	svrConn := &serverConn{
		conn: conn{
			Reader:      rwc,
			WriteCloser: rwc,
		},
	}
	rs := &RequestServer{
		Handlers: h,

		serverConn: svrConn,
		pktMgr:     newPktMgr(svrConn),

		openRequests: make(map[string]*Request),
	}

	for _, o := range options {
		o(rs)
	}
	return rs
}

// New Open packet/Request
func (rs *RequestServer) nextRequest(r *Request) string {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	rs.handleCount++

	r.handle = strconv.Itoa(rs.handleCount)
	rs.openRequests[r.handle] = r

	return r.handle
}

// Returns Request from openRequests, bool is false if it is missing.
//
// The Requests in openRequests work essentially as open file descriptors that
// you can do different things with. What you are doing with it are denoted by
// the first packet of that type (read/write/etc).
func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
	rs.mu.RLock()
	defer rs.mu.RUnlock()

	r, ok := rs.openRequests[handle]
	return r, ok
}

// Close the Request and clear from openRequests map
func (rs *RequestServer) closeRequest(handle string) error {
	rs.mu.Lock()
	defer rs.mu.Unlock()

	if r, ok := rs.openRequests[handle]; ok {
		delete(rs.openRequests, handle)
		return r.close()
	}

	return EBADF
}

// Close the read/write/closer to trigger exiting the main server loop
func (rs *RequestServer) Close() error { return rs.conn.Close() }

func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
	defer close(pktChan) // shuts down sftpServerWorkers

	var err error
	var pkt requestPacket
	var pktType uint8
	var pktBytes []byte

	for {
		pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID())
		if err != nil {
			// we don't care about releasing allocated pages here, the server will quit and the allocator freed
			return err
		}

		pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
		if err != nil {
			switch {
			case errors.Is(err, errUnknownExtendedPacket):
				// do nothing
			default:
				debug("makePacket err: %v", err)
				rs.conn.Close() // shuts down recvPacket
				return err
			}
		}

		pktChan <- rs.pktMgr.newOrderedRequest(pkt)
	}
}

// Serve requests for user session
func (rs *RequestServer) Serve() error {
	defer func() {
		if rs.pktMgr.alloc != nil {
			rs.pktMgr.alloc.Free()
		}
	}()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var wg sync.WaitGroup
	runWorker := func(ch chan orderedRequest) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := rs.packetWorker(ctx, ch); err != nil {
				rs.conn.Close() // shuts down recvPacket
			}
		}()
	}
	pktChan := rs.pktMgr.workerChan(runWorker)

	err := rs.serveLoop(pktChan)

	wg.Wait() // wait for all workers to exit

	rs.mu.Lock()
	defer rs.mu.Unlock()

	// make sure all open requests are properly closed
	// (eg. possible on dropped connections, client crashes, etc.)
	for handle, req := range rs.openRequests {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		req.transferError(err)

		delete(rs.openRequests, handle)
		req.close()
	}

	return err
}

func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedRequest) error {
	for pkt := range pktChan {
		orderID := pkt.orderID()
		if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok {
			if epkt.SpecificPacket != nil {
				pkt.requestPacket = epkt.SpecificPacket
			}
		}

		var rpkt responsePacket
		switch pkt := pkt.requestPacket.(type) {
		case *sshFxInitPacket:
			rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions}
		case *sshFxpClosePacket:
			handle := pkt.getHandle()
			rpkt = statusFromError(pkt.ID, rs.closeRequest(handle))
		case *sshFxpRealpathPacket:
			var realPath string
			if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok {
				realPath = realPather.RealPath(pkt.getPath())
			} else {
				realPath = cleanPath(pkt.getPath())
			}
			rpkt = cleanPacketPath(pkt, realPath)
		case *sshFxpOpendirPacket:
			request := requestFromPacket(ctx, pkt)
			handle := rs.nextRequest(request)
			rpkt = request.opendir(rs.Handlers, pkt)
			if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
				// if we return an error we have to remove the handle from the active ones
				rs.closeRequest(handle)
			}
		case *sshFxpOpenPacket:
			request := requestFromPacket(ctx, pkt)
			handle := rs.nextRequest(request)
			rpkt = request.open(rs.Handlers, pkt)
			if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
				// if we return an error we have to remove the handle from the active ones
				rs.closeRequest(handle)
			}
		case *sshFxpFstatPacket:
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.ID, EBADF)
			} else {
				request = NewRequest("Stat", request.Filepath)
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case *sshFxpFsetstatPacket:
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.ID, EBADF)
			} else {
				request = NewRequest("Setstat", request.Filepath)
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case *sshFxpExtendedPacketPosixRename:
			request := NewRequest("PosixRename", pkt.Oldpath)
			request.Target = pkt.Newpath
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
		case *sshFxpExtendedPacketStatVFS:
			request := NewRequest("StatVFS", pkt.Path)
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
		case hasHandle:
			handle := pkt.getHandle()
			request, ok := rs.getRequest(handle)
			if !ok {
				rpkt = statusFromError(pkt.id(), EBADF)
			} else {
				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			}
		case hasPath:
			request := requestFromPacket(ctx, pkt)
			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
			request.close()
		default:
			rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
		}

		rs.pktMgr.readyPacket(
			rs.pktMgr.newOrderedResponse(rpkt, orderID))
	}
	return nil
}

// clean and return name packet for file
func cleanPacketPath(pkt *sshFxpRealpathPacket, realPath string) responsePacket {
	return &sshFxpNamePacket{
		ID: pkt.id(),
		NameAttrs: []*sshFxpNameAttr{
			{
				Name:     realPath,
				LongName: realPath,
				Attrs:    emptyFileStat,
			},
		},
	}
}

// Makes sure we have a clean POSIX (/) absolute path to work with
func cleanPath(p string) string {
	return cleanPathWithBase("/", p)
}

func cleanPathWithBase(base, p string) string {
	p = filepath.ToSlash(filepath.Clean(p))
	if !path.IsAbs(p) {
		return path.Join(base, p)
	}
	return p
}