diff options
Diffstat (limited to 'sftp/packet_manager.go')
-rw-r--r-- | sftp/packet_manager.go | 221 |
1 files changed, 221 insertions, 0 deletions
diff --git a/sftp/packet_manager.go b/sftp/packet_manager.go new file mode 100644 index 0000000..5aeb72b --- /dev/null +++ b/sftp/packet_manager.go @@ -0,0 +1,221 @@ +package sftp + +/* + Imported from: https://github.com/pkg/sftp + */ + +import ( + "encoding" + "sort" + "sync" +) + +// The goal of the packetManager is to keep the outgoing packets in the same +// order as the incoming as is requires by section 7 of the RFC. + +type packetManager struct { + requests chan orderedPacket + responses chan orderedPacket + fini chan struct{} + incoming orderedPackets + outgoing orderedPackets + sender packetSender // connection object + working *sync.WaitGroup + packetCount uint32 + // it is not nil if the allocator is enabled + alloc *allocator +} + +type packetSender interface { + sendPacket(encoding.BinaryMarshaler) error +} + +func newPktMgr(sender packetSender) *packetManager { + s := &packetManager{ + requests: make(chan orderedPacket, SftpServerWorkerCount), + responses: make(chan orderedPacket, SftpServerWorkerCount), + fini: make(chan struct{}), + incoming: make([]orderedPacket, 0, SftpServerWorkerCount), + outgoing: make([]orderedPacket, 0, SftpServerWorkerCount), + sender: sender, + working: &sync.WaitGroup{}, + } + go s.controller() + return s +} + +//// packet ordering +func (s *packetManager) newOrderID() uint32 { + s.packetCount++ + return s.packetCount +} + +// returns the next orderID without incrementing it. +// This is used before receiving a new packet, with the allocator enabled, to associate +// the slice allocated for the received packet with the orderID that will be used to mark +// the allocated slices for reuse once the request is served +func (s *packetManager) getNextOrderID() uint32 { + return s.packetCount + 1 +} + +type orderedRequest struct { + requestPacket + orderid uint32 +} + +func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest { + return orderedRequest{requestPacket: p, orderid: s.newOrderID()} +} +func (p orderedRequest) orderID() uint32 { return p.orderid } +func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid } + +type orderedResponse struct { + responsePacket + orderid uint32 +} + +func (s *packetManager) newOrderedResponse(p responsePacket, id uint32, +) orderedResponse { + return orderedResponse{responsePacket: p, orderid: id} +} +func (p orderedResponse) orderID() uint32 { return p.orderid } +func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid } + +type orderedPacket interface { + id() uint32 + orderID() uint32 +} +type orderedPackets []orderedPacket + +func (o orderedPackets) Sort() { + sort.Slice(o, func(i, j int) bool { + return o[i].orderID() < o[j].orderID() + }) +} + +//// packet registry +// register incoming packets to be handled +func (s *packetManager) incomingPacket(pkt orderedRequest) { + s.working.Add(1) + s.requests <- pkt +} + +// register outgoing packets as being ready +func (s *packetManager) readyPacket(pkt orderedResponse) { + s.responses <- pkt + s.working.Done() +} + +// shut down packetManager controller +func (s *packetManager) close() { + // pause until current packets are processed + s.working.Wait() + close(s.fini) +} + +// Passed a worker function, returns a channel for incoming packets. +// Keep process packet responses in the order they are received while +// maximizing throughput of file transfers. +func (s *packetManager) workerChan(runWorker func(chan orderedRequest), +) chan orderedRequest { + // multiple workers for faster read/writes + rwChan := make(chan orderedRequest, SftpServerWorkerCount) + for i := 0; i < SftpServerWorkerCount; i++ { + runWorker(rwChan) + } + + // single worker to enforce sequential processing of everything else + cmdChan := make(chan orderedRequest) + runWorker(cmdChan) + + pktChan := make(chan orderedRequest, SftpServerWorkerCount) + go func() { + for pkt := range pktChan { + switch pkt.requestPacket.(type) { + case *sshFxpReadPacket, *sshFxpWritePacket: + s.incomingPacket(pkt) + rwChan <- pkt + continue + case *sshFxpClosePacket: + // wait for reads/writes to finish when file is closed + // incomingPacket() call must occur after this + s.working.Wait() + } + s.incomingPacket(pkt) + // all non-RW use sequential cmdChan + cmdChan <- pkt + } + close(rwChan) + close(cmdChan) + s.close() + }() + + return pktChan +} + +// process packets +func (s *packetManager) controller() { + for { + select { + case pkt := <-s.requests: + debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID()) + s.incoming = append(s.incoming, pkt) + s.incoming.Sort() + case pkt := <-s.responses: + debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID()) + s.outgoing = append(s.outgoing, pkt) + s.outgoing.Sort() + case <-s.fini: + return + } + s.maybeSendPackets() + } +} + +// send as many packets as are ready +func (s *packetManager) maybeSendPackets() { + for { + if len(s.outgoing) == 0 || len(s.incoming) == 0 { + debug("break! -- outgoing: %v; incoming: %v", + len(s.outgoing), len(s.incoming)) + break + } + out := s.outgoing[0] + in := s.incoming[0] + // debug("incoming: %v", ids(s.incoming)) + // debug("outgoing: %v", ids(s.outgoing)) + if in.orderID() == out.orderID() { + debug("Sending packet: %v", out.id()) + s.sender.sendPacket(out.(encoding.BinaryMarshaler)) + if s.alloc != nil { + // mark for reuse the slices allocated for this request + s.alloc.ReleasePages(in.orderID()) + } + // pop off heads + copy(s.incoming, s.incoming[1:]) // shift left + s.incoming[len(s.incoming)-1] = nil // clear last + s.incoming = s.incoming[:len(s.incoming)-1] // remove last + copy(s.outgoing, s.outgoing[1:]) // shift left + s.outgoing[len(s.outgoing)-1] = nil // clear last + s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last + } else { + break + } + } +} + +// func oids(o []orderedPacket) []uint32 { +// res := make([]uint32, 0, len(o)) +// for _, v := range o { +// res = append(res, v.orderId()) +// } +// return res +// } +// func ids(o []orderedPacket) []uint32 { +// res := make([]uint32, 0, len(o)) +// for _, v := range o { +// res = append(res, v.id()) +// } +// return res +// } + |