aboutsummaryrefslogblamecommitdiff
path: root/sftp/packet_manager.go
blob: 5aeb72bae4136d093fc339e142986acfb6bf7f00 (plain) (tree)




























































































































































































































                                                                                         
package sftp

/*
  Imported from: https://github.com/pkg/sftp
 */

import (
	"encoding"
	"sort"
	"sync"
)

// The goal of the packetManager is to keep the outgoing packets in the same
// order as the incoming ones, as is required by section 7 of the RFC.

// packetManager sequences server responses: requests are registered as they
// arrive, responses are registered as workers finish them, and the controller
// goroutine only sends a response once it belongs to the oldest request that
// is still queued.
type packetManager struct {
	// requests delivers newly registered requests to the controller.
	requests    chan orderedPacket
	// responses delivers finished responses to the controller.
	responses   chan orderedPacket
	// fini is closed by close() to make the controller goroutine return.
	fini        chan struct{}
	// incoming holds requests whose response has not been sent yet, sorted
	// by orderID.
	incoming    orderedPackets
	// outgoing holds responses waiting to be sent, sorted by orderID.
	outgoing    orderedPackets
	sender      packetSender // connection object
	// working counts requests registered via incomingPacket that have not yet
	// been answered via readyPacket.
	working     *sync.WaitGroup
	// packetCount is the most recently issued order ID (0 before the first).
	packetCount uint32
	// it is not nil if the allocator is enabled; when set, the pages tagged
	// with a request's order ID are released after its response is sent.
	alloc *allocator
}

// packetSender is what the packetManager needs from the connection: a way to
// send one marshalable packet, reporting any error.
type packetSender interface {
	sendPacket(encoding.BinaryMarshaler) error
}

// newPktMgr creates a packetManager that sends responses through sender,
// sizes its channels and queues to SftpServerWorkerCount, and starts its
// controller goroutine. The allocator is left disabled (alloc is nil).
func newPktMgr(sender packetSender) *packetManager {
	mgr := new(packetManager)
	mgr.sender = sender
	mgr.working = new(sync.WaitGroup)
	mgr.fini = make(chan struct{})
	mgr.requests = make(chan orderedPacket, SftpServerWorkerCount)
	mgr.responses = make(chan orderedPacket, SftpServerWorkerCount)
	mgr.incoming = make(orderedPackets, 0, SftpServerWorkerCount)
	mgr.outgoing = make(orderedPackets, 0, SftpServerWorkerCount)
	go mgr.controller()
	return mgr
}

//// packet ordering

// newOrderID advances the packet counter and returns its new value, so the
// first ID handed out is 1.
func (s *packetManager) newOrderID() uint32 {
	next := s.packetCount + 1
	s.packetCount = next
	return next
}

// getNextOrderID reports the order ID that the next newOrderID call will
// return, without consuming it. With the allocator enabled it is called before
// a packet is received, so the slices allocated for that packet can be tagged
// with the order ID under which they are released once the request is served.
func (s *packetManager) getNextOrderID() uint32 {
	current := s.packetCount
	return current + 1
}

// orderedRequest pairs a request packet with the order ID assigned to it on
// registration, so that its response can be emitted in sequence.
type orderedRequest struct {
	requestPacket
	orderid uint32
}

// newOrderedRequest wraps p with the next order ID issued by s.
func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
	return orderedRequest{requestPacket: p, orderid: s.newOrderID()}
}

// orderID returns the request's order ID. It keeps a value receiver so that
// orderedRequest values (not just pointers) satisfy orderedPacket.
func (p orderedRequest) orderID() uint32 { return p.orderid }

// setOrderID overwrites the request's order ID. It must use a pointer
// receiver: with a value receiver the assignment only modified a copy and was
// silently lost.
func (p *orderedRequest) setOrderID(oid uint32) { p.orderid = oid }

// orderedResponse pairs a response packet with an order ID. The controller
// only sends it once that ID matches the oldest queued request's orderID.
type orderedResponse struct {
	responsePacket
	orderid uint32
}

// newOrderedResponse wraps p with order ID id, which should be the orderID()
// of the request that p answers.
func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
) orderedResponse {
	return orderedResponse{responsePacket: p, orderid: id}
}

// orderID returns the response's order ID. It keeps a value receiver so that
// orderedResponse values satisfy orderedPacket.
func (p orderedResponse) orderID() uint32 { return p.orderid }

// setOrderID overwrites the response's order ID. It must use a pointer
// receiver: with a value receiver the assignment only modified a copy and was
// silently lost.
func (p *orderedResponse) setOrderID(oid uint32) { p.orderid = oid }

// orderedPacket is a packet the packetManager can sequence: it exposes the
// packet's id and the order ID used to sequence it.
type orderedPacket interface {
	id() uint32
	orderID() uint32
}

// orderedPackets is a list of orderedPacket values, such as the manager's
// incoming and outgoing queues.
type orderedPackets []orderedPacket

// Sort reorders the packets in place by ascending orderID.
func (o orderedPackets) Sort() {
	lessByOrderID := func(a, b int) bool {
		return o[a].orderID() < o[b].orderID()
	}
	sort.Slice(o, lessByOrderID)
}

//// packet registry

// incomingPacket registers pkt as in flight and passes it to the controller.
// The WaitGroup is incremented first so that close() and the close-packet
// barrier in workerChan wait for this request.
func (s *packetManager) incomingPacket(pkt orderedRequest) {
	inFlight := s.working
	inFlight.Add(1)
	var queued orderedPacket = pkt
	s.requests <- queued
}

// readyPacket passes a finished response to the controller and only then
// marks its request as done. Because the response is handed over first, every
// response has already been passed to s.responses once working.Wait()
// returns.
func (s *packetManager) readyPacket(pkt orderedResponse) {
	var finished orderedPacket = pkt
	s.responses <- finished
	inFlight := s.working
	inFlight.Done()
}

// close blocks until every registered request has been answered via
// readyPacket, then closes fini so the controller goroutine returns.
func (s *packetManager) close() {
	inFlight := s.working
	inFlight.Wait()
	stop := s.fini
	close(stop)
}

// workerChan starts the server's workers through runWorker and returns the
// channel that incoming requests should be fed into.
//
// Read and write requests are shared among SftpServerWorkerCount workers for
// throughput. Every other request goes to a single dedicated worker, so those
// are processed one at a time. Before a close request is dispatched, all
// outstanding requests are allowed to finish. Closing the returned channel
// closes both worker channels and then shuts the packetManager down once all
// in-flight requests have been answered.
func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
) chan orderedRequest {
	// Pool of workers sharing one channel for reads and writes.
	rwChan := make(chan orderedRequest, SftpServerWorkerCount)
	for n := SftpServerWorkerCount; n > 0; n-- {
		runWorker(rwChan)
	}

	// Unbuffered channel served by exactly one worker: everything that is
	// not a read or write runs sequentially.
	cmdChan := make(chan orderedRequest)
	runWorker(cmdChan)

	pktChan := make(chan orderedRequest, SftpServerWorkerCount)
	go func() {
		for pkt := range pktChan {
			dest := cmdChan
			switch pkt.requestPacket.(type) {
			case *sshFxpReadPacket, *sshFxpWritePacket:
				dest = rwChan
			case *sshFxpClosePacket:
				// Let outstanding reads/writes finish before the close is
				// handled. This has to happen before incomingPacket registers
				// the close itself, otherwise Wait would block on a request
				// that has not been dispatched yet.
				s.working.Wait()
			}
			s.incomingPacket(pkt)
			dest <- pkt
		}
		close(rwChan)
		close(cmdChan)
		s.close()
	}()

	return pktChan
}

// controller is the packetManager's event loop, started as a goroutine by
// newPktMgr. Each request or response it receives is appended to the
// matching queue, which is kept sorted by orderID, and then every response
// that is now in sequence is sent.
//
// When fini is closed, the loop first drains anything still buffered in the
// requests and responses channels, sends what is in sequence, and returns.
func (s *packetManager) controller() {
	queueRequest := func(pkt orderedPacket) {
		debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID())
		s.incoming = append(s.incoming, pkt)
		s.incoming.Sort()
	}
	queueResponse := func(pkt orderedPacket) {
		debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID())
		s.outgoing = append(s.outgoing, pkt)
		s.outgoing.Sort()
	}

	for {
		select {
		case pkt := <-s.requests:
			queueRequest(pkt)
		case pkt := <-s.responses:
			queueResponse(pkt)
		case <-s.fini:
			// close() closes fini right after working.Wait() returns, and
			// readyPacket hands its response over before calling Done, so the
			// last responses may still be sitting in the channel buffers.
			// select picks randomly among ready cases, so without this drain
			// those responses could be dropped instead of sent.
			for {
				select {
				case pkt := <-s.requests:
					queueRequest(pkt)
				case pkt := <-s.responses:
					queueResponse(pkt)
				default:
					s.maybeSendPackets()
					return
				}
			}
		}
		s.maybeSendPackets()
	}
}

// maybeSendPackets sends, in order, every queued response whose orderID
// matches the oldest queued request. For each one it sends it, releases the
// allocator pages tagged with that order ID (when the allocator is enabled),
// and removes both packets from the heads of their queues. It stops when
// either queue is empty or the heads do not match.
//
// A failed send is logged via debug. The packets are still dequeued so that
// the queues keep advancing, matching the previous behaviour.
func (s *packetManager) maybeSendPackets() {
	for {
		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
			debug("break! -- outgoing: %v; incoming: %v",
				len(s.outgoing), len(s.incoming))
			return
		}
		out := s.outgoing[0]
		in := s.incoming[0]
		// debug("incoming: %v", ids(s.incoming))
		// debug("outgoing: %v", ids(s.outgoing))
		if in.orderID() != out.orderID() {
			// The response for the oldest queued request has not arrived
			// yet, so nothing further can be sent in order.
			return
		}

		debug("Sending packet: %v", out.id())
		if err := s.sender.sendPacket(out.(encoding.BinaryMarshaler)); err != nil {
			debug("error sending packet %v: %v", out.id(), err)
		}
		if s.alloc != nil {
			// mark for reuse the slices allocated for this request
			s.alloc.ReleasePages(in.orderID())
		}

		// Pop both heads in place: shift left, nil the vacated last slot so
		// the packet is no longer referenced, then shrink the length.
		copy(s.incoming, s.incoming[1:])
		s.incoming[len(s.incoming)-1] = nil
		s.incoming = s.incoming[:len(s.incoming)-1]
		copy(s.outgoing, s.outgoing[1:])
		s.outgoing[len(s.outgoing)-1] = nil
		s.outgoing = s.outgoing[:len(s.outgoing)-1]
	}
}

// func oids(o []orderedPacket) []uint32 {
// 	res := make([]uint32, 0, len(o))
// 	for _, v := range o {
// 		res = append(res, v.orderID())
// 	}
// 	return res
// }
// func ids(o []orderedPacket) []uint32 {
// 	res := make([]uint32, 0, len(o))
// 	for _, v := range o {
// 		res = append(res, v.id())
// 	}
// 	return res
// }