From 0ee29e31ddcc81f541de7459b0a5e40dfa552672 Mon Sep 17 00:00:00 2001 From: Quentin Dufour Date: Fri, 19 Nov 2021 19:54:49 +0100 Subject: Working on SFTP --- sftp/allocator.go | 101 +++ sftp/attrs.go | 90 ++ sftp/attrs_stubs.go | 11 + sftp/attrs_unix.go | 16 + sftp/client.go | 1936 +++++++++++++++++++++++++++++++++++++++++ sftp/conn.go | 189 ++++ sftp/debug.go | 9 + sftp/ls_formatting.go | 81 ++ sftp/ls_plan9.go | 21 + sftp/ls_stub.go | 11 + sftp/ls_unix.go | 23 + sftp/packet.go | 1276 +++++++++++++++++++++++++++ sftp/packet_manager.go | 221 +++++ sftp/packet_typing.go | 135 +++ sftp/pool.go | 79 ++ sftp/release.go | 5 + sftp/request-attrs.go | 63 ++ sftp/request-errors.go | 54 ++ sftp/request-interfaces.go | 121 +++ sftp/request-plan9.go | 34 + sftp/request-server.go | 304 +++++++ sftp/request-unix.go | 27 + sftp/request.go | 628 +++++++++++++ sftp/request_windows.go | 44 + sftp/server.go | 643 ++++++++++++++ sftp/server_statvfs_darwin.go | 21 + sftp/server_statvfs_impl.go | 29 + sftp/server_statvfs_linux.go | 22 + sftp/server_statvfs_plan9.go | 13 + sftp/server_statvfs_stubs.go | 15 + sftp/sftp.go | 258 ++++++ sftp/stat_plan9.go | 103 +++ sftp/stat_posix.go | 124 +++ sftp/syscall_fixed.go | 9 + sftp/syscall_good.go | 8 + 35 files changed, 6724 insertions(+) create mode 100644 sftp/allocator.go create mode 100644 sftp/attrs.go create mode 100644 sftp/attrs_stubs.go create mode 100644 sftp/attrs_unix.go create mode 100644 sftp/client.go create mode 100644 sftp/conn.go create mode 100644 sftp/debug.go create mode 100644 sftp/ls_formatting.go create mode 100644 sftp/ls_plan9.go create mode 100644 sftp/ls_stub.go create mode 100644 sftp/ls_unix.go create mode 100644 sftp/packet.go create mode 100644 sftp/packet_manager.go create mode 100644 sftp/packet_typing.go create mode 100644 sftp/pool.go create mode 100644 sftp/release.go create mode 100644 sftp/request-attrs.go create mode 100644 sftp/request-errors.go create mode 100644 sftp/request-interfaces.go create mode 100644 sftp/request-plan9.go create mode 100644 sftp/request-server.go create mode 100644 sftp/request-unix.go create mode 100644 sftp/request.go create mode 100644 sftp/request_windows.go create mode 100644 sftp/server.go create mode 100644 sftp/server_statvfs_darwin.go create mode 100644 sftp/server_statvfs_impl.go create mode 100644 sftp/server_statvfs_linux.go create mode 100644 sftp/server_statvfs_plan9.go create mode 100644 sftp/server_statvfs_stubs.go create mode 100644 sftp/sftp.go create mode 100644 sftp/stat_plan9.go create mode 100644 sftp/stat_posix.go create mode 100644 sftp/syscall_fixed.go create mode 100644 sftp/syscall_good.go (limited to 'sftp') diff --git a/sftp/allocator.go b/sftp/allocator.go new file mode 100644 index 0000000..fc1b6f0 --- /dev/null +++ b/sftp/allocator.go @@ -0,0 +1,101 @@ +package sftp + +/* + Imported from: https://github.com/pkg/sftp + */ + +import ( + "sync" +) + +type allocator struct { + sync.Mutex + available [][]byte + // map key is the request order + used map[uint32][][]byte +} + +func newAllocator() *allocator { + return &allocator{ + // micro optimization: initialize available pages with an initial capacity + available: make([][]byte, 0, SftpServerWorkerCount*2), + used: make(map[uint32][][]byte), + } +} + +// GetPage returns a previously allocated and unused []byte or create a new one. +// The slice have a fixed size = maxMsgLength, this value is suitable for both +// receiving new packets and reading the files to serve +func (a *allocator) GetPage(requestOrderID uint32) []byte { + a.Lock() + defer a.Unlock() + + var result []byte + + // get an available page and remove it from the available ones. + if len(a.available) > 0 { + truncLength := len(a.available) - 1 + result = a.available[truncLength] + + a.available[truncLength] = nil // clear out the internal pointer + a.available = a.available[:truncLength] // truncate the slice + } + + // no preallocated slice found, just allocate a new one + if result == nil { + result = make([]byte, maxMsgLength) + } + + // put result in used pages + a.used[requestOrderID] = append(a.used[requestOrderID], result) + + return result +} + +// ReleasePages marks unused all pages in use for the given requestID +func (a *allocator) ReleasePages(requestOrderID uint32) { + a.Lock() + defer a.Unlock() + + if used := a.used[requestOrderID]; len(used) > 0 { + a.available = append(a.available, used...) + } + delete(a.used, requestOrderID) +} + +// Free removes all the used and available pages. +// Call this method when the allocator is not needed anymore +func (a *allocator) Free() { + a.Lock() + defer a.Unlock() + + a.available = nil + a.used = make(map[uint32][][]byte) +} + +func (a *allocator) countUsedPages() int { + a.Lock() + defer a.Unlock() + + num := 0 + for _, p := range a.used { + num += len(p) + } + return num +} + +func (a *allocator) countAvailablePages() int { + a.Lock() + defer a.Unlock() + + return len(a.available) +} + +func (a *allocator) isRequestOrderIDUsed(requestOrderID uint32) bool { + a.Lock() + defer a.Unlock() + + _, ok := a.used[requestOrderID] + return ok +} + diff --git a/sftp/attrs.go b/sftp/attrs.go new file mode 100644 index 0000000..2bb2d57 --- /dev/null +++ b/sftp/attrs.go @@ -0,0 +1,90 @@ +package sftp + +// ssh_FXP_ATTRS support +// see http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-5 + +import ( + "os" + "time" +) + +const ( + sshFileXferAttrSize = 0x00000001 + sshFileXferAttrUIDGID = 0x00000002 + sshFileXferAttrPermissions = 0x00000004 + sshFileXferAttrACmodTime = 0x00000008 + sshFileXferAttrExtended = 0x80000000 + + sshFileXferAttrAll = sshFileXferAttrSize | sshFileXferAttrUIDGID | sshFileXferAttrPermissions | + sshFileXferAttrACmodTime | sshFileXferAttrExtended +) + +// fileInfo is an artificial type designed to satisfy os.FileInfo. +type fileInfo struct { + name string + stat *FileStat +} + +// Name returns the base name of the file. +func (fi *fileInfo) Name() string { return fi.name } + +// Size returns the length in bytes for regular files; system-dependent for others. +func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) } + +// Mode returns file mode bits. +func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) } + +// ModTime returns the last modification time of the file. +func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) } + +// IsDir returns true if the file is a directory. +func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() } + +func (fi *fileInfo) Sys() interface{} { return fi.stat } + +// FileStat holds the original unmarshalled values from a call to READDIR or +// *STAT. It is exported for the purposes of accessing the raw values via +// os.FileInfo.Sys(). It is also used server side to store the unmarshalled +// values for SetStat. +type FileStat struct { + Size uint64 + Mode uint32 + Mtime uint32 + Atime uint32 + UID uint32 + GID uint32 + Extended []StatExtended +} + +// StatExtended contains additional, extended information for a FileStat. +type StatExtended struct { + ExtType string + ExtData string +} + +func fileInfoFromStat(stat *FileStat, name string) os.FileInfo { + return &fileInfo{ + name: name, + stat: stat, + } +} + +func fileStatFromInfo(fi os.FileInfo) (uint32, *FileStat) { + mtime := fi.ModTime().Unix() + atime := mtime + var flags uint32 = sshFileXferAttrSize | + sshFileXferAttrPermissions | + sshFileXferAttrACmodTime + + fileStat := &FileStat{ + Size: uint64(fi.Size()), + Mode: fromFileMode(fi.Mode()), + Mtime: uint32(mtime), + Atime: uint32(atime), + } + + // os specific file stat decoding + fileStatFromInfoOs(fi, &flags, fileStat) + + return flags, fileStat +} diff --git a/sftp/attrs_stubs.go b/sftp/attrs_stubs.go new file mode 100644 index 0000000..c01f336 --- /dev/null +++ b/sftp/attrs_stubs.go @@ -0,0 +1,11 @@ +// +build plan9 windows android + +package sftp + +import ( + "os" +) + +func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) { + // todo +} diff --git a/sftp/attrs_unix.go b/sftp/attrs_unix.go new file mode 100644 index 0000000..d1f4452 --- /dev/null +++ b/sftp/attrs_unix.go @@ -0,0 +1,16 @@ +// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris aix js + +package sftp + +import ( + "os" + "syscall" +) + +func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) { + if statt, ok := fi.Sys().(*syscall.Stat_t); ok { + *flags |= sshFileXferAttrUIDGID + fileStat.UID = statt.Uid + fileStat.GID = statt.Gid + } +} diff --git a/sftp/client.go b/sftp/client.go new file mode 100644 index 0000000..ce62286 --- /dev/null +++ b/sftp/client.go @@ -0,0 +1,1936 @@ +package sftp + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "math" + "os" + "path" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/kr/fs" + "golang.org/x/crypto/ssh" +) + +var ( + // ErrInternalInconsistency indicates the packets sent and the data queued to be + // written to the file don't match up. It is an unusual error and usually is + // caused by bad behavior server side or connection issues. The error is + // limited in scope to the call where it happened, the client object is still + // OK to use as long as the connection is still open. + ErrInternalInconsistency = errors.New("internal inconsistency") + // InternalInconsistency alias for ErrInternalInconsistency. + // + // Deprecated: please use ErrInternalInconsistency + InternalInconsistency = ErrInternalInconsistency +) + +// A ClientOption is a function which applies configuration to a Client. +type ClientOption func(*Client) error + +// MaxPacketChecked sets the maximum size of the payload, measured in bytes. +// This option only accepts sizes servers should support, ie. <= 32768 bytes. +// +// If you get the error "failed to send packet header: EOF" when copying a +// large file, try lowering this number. +// +// The default packet size is 32768 bytes. +func MaxPacketChecked(size int) ClientOption { + return func(c *Client) error { + if size < 1 { + return errors.New("size must be greater or equal to 1") + } + if size > 32768 { + return errors.New("sizes larger than 32KB might not work with all servers") + } + c.maxPacket = size + return nil + } +} + +// MaxPacketUnchecked sets the maximum size of the payload, measured in bytes. +// It accepts sizes larger than the 32768 bytes all servers should support. +// Only use a setting higher than 32768 if your application always connects to +// the same server or after sufficiently broad testing. +// +// If you get the error "failed to send packet header: EOF" when copying a +// large file, try lowering this number. +// +// The default packet size is 32768 bytes. +func MaxPacketUnchecked(size int) ClientOption { + return func(c *Client) error { + if size < 1 { + return errors.New("size must be greater or equal to 1") + } + c.maxPacket = size + return nil + } +} + +// MaxPacket sets the maximum size of the payload, measured in bytes. +// This option only accepts sizes servers should support, ie. <= 32768 bytes. +// This is a synonym for MaxPacketChecked that provides backward compatibility. +// +// If you get the error "failed to send packet header: EOF" when copying a +// large file, try lowering this number. +// +// The default packet size is 32768 bytes. +func MaxPacket(size int) ClientOption { + return MaxPacketChecked(size) +} + +// MaxConcurrentRequestsPerFile sets the maximum concurrent requests allowed for a single file. +// +// The default maximum concurrent requests is 64. +func MaxConcurrentRequestsPerFile(n int) ClientOption { + return func(c *Client) error { + if n < 1 { + return errors.New("n must be greater or equal to 1") + } + c.maxConcurrentRequests = n + return nil + } +} + +// UseConcurrentWrites allows the Client to perform concurrent Writes. +// +// Using concurrency while doing writes, requires special consideration. +// A write to a later offset in a file after an error, +// could end up with a file length longer than what was successfully written. +// +// When using this option, if you receive an error during `io.Copy` or `io.WriteTo`, +// you may need to `Truncate` the target Writer to avoid “holes” in the data written. +func UseConcurrentWrites(value bool) ClientOption { + return func(c *Client) error { + c.useConcurrentWrites = value + return nil + } +} + +// UseConcurrentReads allows the Client to perform concurrent Reads. +// +// Concurrent reads are generally safe to use and not using them will degrade +// performance, so this option is enabled by default. +// +// When enabled, WriteTo will use Stat/Fstat to get the file size and determines +// how many concurrent workers to use. +// Some "read once" servers will delete the file if they receive a stat call on an +// open file and then the download will fail. +// Disabling concurrent reads you will be able to download files from these servers. +// If concurrent reads are disabled, the UseFstat option is ignored. +func UseConcurrentReads(value bool) ClientOption { + return func(c *Client) error { + c.disableConcurrentReads = !value + return nil + } +} + +// UseFstat sets whether to use Fstat or Stat when File.WriteTo is called +// (usually when copying files). +// Some servers limit the amount of open files and calling Stat after opening +// the file will throw an error From the server. Setting this flag will call +// Fstat instead of Stat which is suppose to be called on an open file handle. +// +// It has been found that that with IBM Sterling SFTP servers which have +// "extractability" level set to 1 which means only 1 file can be opened at +// any given time. +// +// If the server you are working with still has an issue with both Stat and +// Fstat calls you can always open a file and read it until the end. +// +// Another reason to read the file until its end and Fstat doesn't work is +// that in some servers, reading a full file will automatically delete the +// file as some of these mainframes map the file to a message in a queue. +// Once the file has been read it will get deleted. +func UseFstat(value bool) ClientOption { + return func(c *Client) error { + c.useFstat = value + return nil + } +} + +// Client represents an SFTP session on a *ssh.ClientConn SSH connection. +// Multiple Clients can be active on a single SSH connection, and a Client +// may be called concurrently from multiple Goroutines. +// +// Client implements the github.com/kr/fs.FileSystem interface. +type Client struct { + clientConn + + ext map[string]string // Extensions (name -> data). + + maxPacket int // max packet size read or written. + maxConcurrentRequests int + nextid uint32 + + // write concurrency is… error prone. + // Default behavior should be to not use it. + useConcurrentWrites bool + useFstat bool + disableConcurrentReads bool +} + +// NewClient creates a new SFTP client on conn, using zero or more option +// functions. +func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { + s, err := conn.NewSession() + if err != nil { + return nil, err + } + if err := s.RequestSubsystem("sftp"); err != nil { + return nil, err + } + pw, err := s.StdinPipe() + if err != nil { + return nil, err + } + pr, err := s.StdoutPipe() + if err != nil { + return nil, err + } + + return NewClientPipe(pr, pw, opts...) +} + +// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. +// This can be used for connecting to an SFTP server over TCP/TLS or by using +// the system's ssh client program (e.g. via exec.Command). +func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { + sftp := &Client{ + clientConn: clientConn{ + conn: conn{ + Reader: rd, + WriteCloser: wr, + }, + inflight: make(map[uint32]chan<- result), + closed: make(chan struct{}), + }, + + ext: make(map[string]string), + + maxPacket: 1 << 15, + maxConcurrentRequests: 64, + } + + for _, opt := range opts { + if err := opt(sftp); err != nil { + wr.Close() + return nil, err + } + } + + if err := sftp.sendInit(); err != nil { + wr.Close() + return nil, err + } + if err := sftp.recvVersion(); err != nil { + wr.Close() + return nil, err + } + + sftp.clientConn.wg.Add(1) + go sftp.loop() + + return sftp, nil +} + +// Create creates the named file mode 0666 (before umask), truncating it if it +// already exists. If successful, methods on the returned File can be used for +// I/O; the associated file descriptor has mode O_RDWR. If you need more +// control over the flags/mode used to open the file see client.OpenFile. +// +// Note that some SFTP servers (eg. AWS Transfer) do not support opening files +// read/write at the same time. For those services you will need to use +// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`. +func (c *Client) Create(path string) (*File, error) { + return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) +} + +const sftpProtocolVersion = 3 // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 + +func (c *Client) sendInit() error { + return c.clientConn.conn.sendPacket(&sshFxInitPacket{ + Version: sftpProtocolVersion, // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 + }) +} + +// returns the next value of c.nextid +func (c *Client) nextID() uint32 { + return atomic.AddUint32(&c.nextid, 1) +} + +func (c *Client) recvVersion() error { + typ, data, err := c.recvPacket(0) + if err != nil { + return err + } + if typ != sshFxpVersion { + return &unexpectedPacketErr{sshFxpVersion, typ} + } + + version, data, err := unmarshalUint32Safe(data) + if err != nil { + return err + } + if version != sftpProtocolVersion { + return &unexpectedVersionErr{sftpProtocolVersion, version} + } + + for len(data) > 0 { + var ext extensionPair + ext, data, err = unmarshalExtensionPair(data) + if err != nil { + return err + } + c.ext[ext.Name] = ext.Data + } + + return nil +} + +// HasExtension checks whether the server supports a named extension. +// +// The first return value is the extension data reported by the server +// (typically a version number). +func (c *Client) HasExtension(name string) (string, bool) { + data, ok := c.ext[name] + return data, ok +} + +// Walk returns a new Walker rooted at root. +func (c *Client) Walk(root string) *fs.Walker { + return fs.WalkFS(root, c) +} + +// ReadDir reads the directory named by dirname and returns a list of +// directory entries. +func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { + handle, err := c.opendir(p) + if err != nil { + return nil, err + } + defer c.close(handle) // this has to defer earlier than the lock below + var attrs []os.FileInfo + var done = false + for !done { + id := c.nextID() + typ, data, err1 := c.sendPacket(nil, &sshFxpReaddirPacket{ + ID: id, + Handle: handle, + }) + if err1 != nil { + err = err1 + done = true + break + } + switch typ { + case sshFxpName: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + count, data := unmarshalUint32(data) + for i := uint32(0); i < count; i++ { + var filename string + filename, data = unmarshalString(data) + _, data = unmarshalString(data) // discard longname + var attr *FileStat + attr, data = unmarshalAttrs(data) + if filename == "." || filename == ".." { + continue + } + attrs = append(attrs, fileInfoFromStat(attr, path.Base(filename))) + } + case sshFxpStatus: + // TODO(dfc) scope warning! + err = normaliseError(unmarshalStatus(id, data)) + done = true + default: + return nil, unimplementedPacketErr(typ) + } + } + if err == io.EOF { + err = nil + } + return attrs, err +} + +func (c *Client) opendir(path string) (string, error) { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpOpendirPacket{ + ID: id, + Path: path, + }) + if err != nil { + return "", err + } + switch typ { + case sshFxpHandle: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIDErr{id, sid} + } + handle, _ := unmarshalString(data) + return handle, nil + case sshFxpStatus: + return "", normaliseError(unmarshalStatus(id, data)) + default: + return "", unimplementedPacketErr(typ) + } +} + +// Stat returns a FileInfo structure describing the file specified by path 'p'. +// If 'p' is a symbolic link, the returned FileInfo structure describes the referent file. +func (c *Client) Stat(p string) (os.FileInfo, error) { + fs, err := c.stat(p) + if err != nil { + return nil, err + } + return fileInfoFromStat(fs, path.Base(p)), nil +} + +// Lstat returns a FileInfo structure describing the file specified by path 'p'. +// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link. +func (c *Client) Lstat(p string) (os.FileInfo, error) { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpLstatPacket{ + ID: id, + Path: p, + }) + if err != nil { + return nil, err + } + switch typ { + case sshFxpAttrs: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return fileInfoFromStat(attr, path.Base(p)), nil + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// ReadLink reads the target of a symbolic link. +func (c *Client) ReadLink(p string) (string, error) { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpReadlinkPacket{ + ID: id, + Path: p, + }) + if err != nil { + return "", err + } + switch typ { + case sshFxpName: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIDErr{id, sid} + } + count, data := unmarshalUint32(data) + if count != 1 { + return "", unexpectedCount(1, count) + } + filename, _ := unmarshalString(data) // ignore dummy attributes + return filename, nil + case sshFxpStatus: + return "", normaliseError(unmarshalStatus(id, data)) + default: + return "", unimplementedPacketErr(typ) + } +} + +// Link creates a hard link at 'newname', pointing at the same inode as 'oldname' +func (c *Client) Link(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpHardlinkPacket{ + ID: id, + Oldpath: oldname, + Newpath: newname, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// Symlink creates a symbolic link at 'newname', pointing at target 'oldname' +func (c *Client) Symlink(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpSymlinkPacket{ + ID: id, + Linkpath: newname, + Targetpath: oldname, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpFsetstatPacket{ + ID: id, + Handle: handle, + Flags: flags, + Attrs: attrs, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// setstat is a convience wrapper to allow for changing of various parts of the file descriptor. +func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpSetstatPacket{ + ID: id, + Path: path, + Flags: flags, + Attrs: attrs, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// Chtimes changes the access and modification times of the named file. +func (c *Client) Chtimes(path string, atime time.Time, mtime time.Time) error { + type times struct { + Atime uint32 + Mtime uint32 + } + attrs := times{uint32(atime.Unix()), uint32(mtime.Unix())} + return c.setstat(path, sshFileXferAttrACmodTime, attrs) +} + +// Chown changes the user and group owners of the named file. +func (c *Client) Chown(path string, uid, gid int) error { + type owner struct { + UID uint32 + GID uint32 + } + attrs := owner{uint32(uid), uint32(gid)} + return c.setstat(path, sshFileXferAttrUIDGID, attrs) +} + +// Chmod changes the permissions of the named file. +// +// Chmod does not apply a umask, because even retrieving the umask is not +// possible in a portable way without causing a race condition. Callers +// should mask off umask bits, if desired. +func (c *Client) Chmod(path string, mode os.FileMode) error { + return c.setstat(path, sshFileXferAttrPermissions, toChmodPerm(mode)) +} + +// Truncate sets the size of the named file. Although it may be safely assumed +// that if the size is less than its current size it will be truncated to fit, +// the SFTP protocol does not specify what behavior the server should do when setting +// size greater than the current size. +func (c *Client) Truncate(path string, size int64) error { + return c.setstat(path, sshFileXferAttrSize, uint64(size)) +} + +// Open opens the named file for reading. If successful, methods on the +// returned file can be used for reading; the associated file descriptor +// has mode O_RDONLY. +func (c *Client) Open(path string) (*File, error) { + return c.open(path, flags(os.O_RDONLY)) +} + +// OpenFile is the generalized open call; most users will use Open or +// Create instead. It opens the named file with specified flag (O_RDONLY +// etc.). If successful, methods on the returned File can be used for I/O. +func (c *Client) OpenFile(path string, f int) (*File, error) { + return c.open(path, flags(f)) +} + +func (c *Client) open(path string, pflags uint32) (*File, error) { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpOpenPacket{ + ID: id, + Path: path, + Pflags: pflags, + }) + if err != nil { + return nil, err + } + switch typ { + case sshFxpHandle: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + handle, _ := unmarshalString(data) + return &File{c: c, path: path, handle: handle}, nil + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// close closes a handle handle previously returned in the response +// to SSH_FXP_OPEN or SSH_FXP_OPENDIR. The handle becomes invalid +// immediately after this request has been sent. +func (c *Client) close(handle string) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpClosePacket{ + ID: id, + Handle: handle, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +func (c *Client) stat(path string) (*FileStat, error) { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{ + ID: id, + Path: path, + }) + if err != nil { + return nil, err + } + switch typ { + case sshFxpAttrs: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return attr, nil + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +func (c *Client) fstat(handle string) (*FileStat, error) { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{ + ID: id, + Handle: handle, + }) + if err != nil { + return nil, err + } + switch typ { + case sshFxpAttrs: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + attr, _ := unmarshalAttrs(data) + return attr, nil + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// StatVFS retrieves VFS statistics from a remote host. +// +// It implements the statvfs@openssh.com SSH_FXP_EXTENDED feature +// from http://www.opensource.apple.com/source/OpenSSH/OpenSSH-175/openssh/PROTOCOL?txt. +func (c *Client) StatVFS(path string) (*StatVFS, error) { + // send the StatVFS packet to the server + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpStatvfsPacket{ + ID: id, + Path: path, + }) + if err != nil { + return nil, err + } + + switch typ { + // server responded with valid data + case sshFxpExtendedReply: + var response StatVFS + err = binary.Read(bytes.NewReader(data), binary.BigEndian, &response) + if err != nil { + return nil, errors.New("can not parse reply") + } + + return &response, nil + + // the resquest failed + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) + + default: + return nil, unimplementedPacketErr(typ) + } +} + +// Join joins any number of path elements into a single path, adding a +// separating slash if necessary. The result is Cleaned; in particular, all +// empty strings are ignored. +func (c *Client) Join(elem ...string) string { return path.Join(elem...) } + +// Remove removes the specified file or directory. An error will be returned if no +// file or directory with the specified path exists, or if the specified directory +// is not empty. +func (c *Client) Remove(path string) error { + err := c.removeFile(path) + // some servers, *cough* osx *cough*, return EPERM, not ENODIR. + // serv-u returns ssh_FX_FILE_IS_A_DIRECTORY + // EPERM is converted to os.ErrPermission so it is not a StatusError + if err, ok := err.(*StatusError); ok { + switch err.Code { + case sshFxFailure, sshFxFileIsADirectory: + return c.RemoveDirectory(path) + } + } + if os.IsPermission(err) { + return c.RemoveDirectory(path) + } + return err +} + +func (c *Client) removeFile(path string) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpRemovePacket{ + ID: id, + Filename: path, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// RemoveDirectory removes a directory path. +func (c *Client) RemoveDirectory(path string) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpRmdirPacket{ + ID: id, + Path: path, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// Rename renames a file. +func (c *Client) Rename(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpRenamePacket{ + ID: id, + Oldpath: oldname, + Newpath: newname, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// PosixRename renames a file using the posix-rename@openssh.com extension +// which will replace newname if it already exists. +func (c *Client) PosixRename(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpPosixRenamePacket{ + ID: id, + Oldpath: oldname, + Newpath: newname, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// RealPath can be used to have the server canonicalize any given path name to an absolute path. +// +// This is useful for converting path names containing ".." components, +// or relative pathnames without a leading slash into absolute paths. +func (c *Client) RealPath(path string) (string, error) { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpRealpathPacket{ + ID: id, + Path: path, + }) + if err != nil { + return "", err + } + switch typ { + case sshFxpName: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIDErr{id, sid} + } + count, data := unmarshalUint32(data) + if count != 1 { + return "", unexpectedCount(1, count) + } + filename, _ := unmarshalString(data) // ignore attributes + return filename, nil + case sshFxpStatus: + return "", normaliseError(unmarshalStatus(id, data)) + default: + return "", unimplementedPacketErr(typ) + } +} + +// Getwd returns the current working directory of the server. Operations +// involving relative paths will be based at this location. +func (c *Client) Getwd() (string, error) { + return c.RealPath(".") +} + +// Mkdir creates the specified directory. An error will be returned if a file or +// directory with the specified path already exists, or if the directory's +// parent folder does not exist (the method cannot create complete paths). +func (c *Client) Mkdir(path string) error { + id := c.nextID() + typ, data, err := c.sendPacket(nil, &sshFxpMkdirPacket{ + ID: id, + Path: path, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// MkdirAll creates a directory named path, along with any necessary parents, +// and returns nil, or else returns an error. +// If path is already a directory, MkdirAll does nothing and returns nil. +// If path contains a regular file, an error is returned +func (c *Client) MkdirAll(path string) error { + // Most of this code mimics https://golang.org/src/os/path.go?s=514:561#L13 + // Fast path: if we can tell whether path is a directory or file, stop with success or error. + dir, err := c.Stat(path) + if err == nil { + if dir.IsDir() { + return nil + } + return &os.PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR} + } + + // Slow path: make sure parent exists and then call Mkdir for path. + i := len(path) + for i > 0 && path[i-1] == '/' { // Skip trailing path separator. + i-- + } + + j := i + for j > 0 && path[j-1] != '/' { // Scan backward over element. + j-- + } + + if j > 1 { + // Create parent + err = c.MkdirAll(path[0 : j-1]) + if err != nil { + return err + } + } + + // Parent now exists; invoke Mkdir and use its result. + err = c.Mkdir(path) + if err != nil { + // Handle arguments like "foo/." by + // double-checking that directory doesn't exist. + dir, err1 := c.Lstat(path) + if err1 == nil && dir.IsDir() { + return nil + } + return err + } + return nil +} + +// File represents a remote file. +type File struct { + c *Client + path string + handle string + + mu sync.Mutex + offset int64 // current offset within remote file +} + +// Close closes the File, rendering it unusable for I/O. It returns an +// error, if any. +func (f *File) Close() error { + return f.c.close(f.handle) +} + +// Name returns the name of the file as presented to Open or Create. +func (f *File) Name() string { + return f.path +} + +// Read reads up to len(b) bytes from the File. It returns the number of bytes +// read and an error, if any. Read follows io.Reader semantics, so when Read +// encounters an error or EOF condition after successfully reading n > 0 bytes, +// it returns the number of bytes read. +// +// To maximise throughput for transferring the entire file (especially +// over high latency links) it is recommended to use WriteTo rather +// than calling Read multiple times. io.Copy will do this +// automatically. +func (f *File) Read(b []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + + n, err := f.ReadAt(b, f.offset) + f.offset += int64(n) + return n, err +} + +// readChunkAt attempts to read the whole entire length of the buffer from the file starting at the offset. +// It will continue progressively reading into the buffer until it fills the whole buffer, or an error occurs. +func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) { + for err == nil && n < len(b) { + id := f.c.nextID() + typ, data, err := f.c.sendPacket(ch, &sshFxpReadPacket{ + ID: id, + Handle: f.handle, + Offset: uint64(off) + uint64(n), + Len: uint32(len(b) - n), + }) + if err != nil { + return n, err + } + + switch typ { + case sshFxpStatus: + return n, normaliseError(unmarshalStatus(id, data)) + + case sshFxpData: + sid, data := unmarshalUint32(data) + if id != sid { + return n, &unexpectedIDErr{id, sid} + } + + l, data := unmarshalUint32(data) + n += copy(b[n:], data[:l]) + + default: + return n, unimplementedPacketErr(typ) + } + } + + return +} + +func (f *File) readAtSequential(b []byte, off int64) (read int, err error) { + for read < len(b) { + rb := b[read:] + if len(rb) > f.c.maxPacket { + rb = rb[:f.c.maxPacket] + } + n, err := f.readChunkAt(nil, rb, off+int64(read)) + if n < 0 { + panic("sftp.File: returned negative count from readChunkAt") + } + if n > 0 { + read += n + } + if err != nil { + if errors.Is(err, io.EOF) { + return read, nil // return nil explicitly. + } + return read, err + } + } + return read, nil +} + +// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns +// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, +// so the file offset is not altered during the read. +func (f *File) ReadAt(b []byte, off int64) (int, error) { + if len(b) <= f.c.maxPacket { + // This should be able to be serviced with 1/2 requests. + // So, just do it directly. + return f.readChunkAt(nil, b, off) + } + + if f.c.disableConcurrentReads { + return f.readAtSequential(b, off) + } + + // Split the read into multiple maxPacket-sized concurrent reads bounded by maxConcurrentRequests. + // This allows writes with a suitably large buffer to transfer data at a much faster rate + // by overlapping round trip times. + + cancel := make(chan struct{}) + + concurrency := len(b)/f.c.maxPacket + 1 + if concurrency > f.c.maxConcurrentRequests || concurrency < 1 { + concurrency = f.c.maxConcurrentRequests + } + + resPool := newResChanPool(concurrency) + + type work struct { + id uint32 + res chan result + + b []byte + off int64 + } + workCh := make(chan work) + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + b := b + offset := off + chunkSize := f.c.maxPacket + + for len(b) > 0 { + rb := b + if len(rb) > chunkSize { + rb = rb[:chunkSize] + } + + id := f.c.nextID() + res := resPool.Get() + + f.c.dispatchRequest(res, &sshFxpReadPacket{ + ID: id, + Handle: f.handle, + Offset: uint64(offset), + Len: uint32(chunkSize), + }) + + select { + case workCh <- work{id, res, rb, offset}: + case <-cancel: + return + } + + offset += int64(len(rb)) + b = b[len(rb):] + } + }() + + type rErr struct { + off int64 + err error + } + errCh := make(chan rErr) + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and then performs the Read into its buffer from its respective offset. + go func() { + defer wg.Done() + + for packet := range workCh { + var n int + + s := <-packet.res + resPool.Put(packet.res) + + err := s.err + if err == nil { + switch s.typ { + case sshFxpStatus: + err = normaliseError(unmarshalStatus(packet.id, s.data)) + + case sshFxpData: + sid, data := unmarshalUint32(s.data) + if packet.id != sid { + err = &unexpectedIDErr{packet.id, sid} + + } else { + l, data := unmarshalUint32(data) + n = copy(packet.b, data[:l]) + + // For normal disk files, it is guaranteed that this will read + // the specified number of bytes, or up to end of file. + // This implies, if we have a short read, that means EOF. + if n < len(packet.b) { + err = io.EOF + } + } + + default: + err = unimplementedPacketErr(s.typ) + } + } + + if err != nil { + // return the offset as the start + how much we read before the error. + errCh <- rErr{packet.off + int64(n), err} + return + } + } + }() + } + + // Wait for long tail, before closing results. + go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: collect all the results into a relevant return: the earliest offset to return an error. + firstErr := rErr{math.MaxInt64, nil} + for rErr := range errCh { + if rErr.off <= firstErr.off { + firstErr = rErr + } + + select { + case <-cancel: + default: + // stop any more work from being distributed. (Just in case.) + close(cancel) + } + } + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off > our starting offset. + return int(firstErr.off - off), firstErr.err + } + + // As per spec for io.ReaderAt, we return nil error if and only if we read everything. + return len(b), nil +} + +// writeToSequential implements WriteTo, but works sequentially with no parallelism. +func (f *File) writeToSequential(w io.Writer) (written int64, err error) { + b := make([]byte, f.c.maxPacket) + ch := make(chan result, 1) // reusable channel + + for { + n, err := f.readChunkAt(ch, b, f.offset) + if n < 0 { + panic("sftp.File: returned negative count from readChunkAt") + } + + if n > 0 { + f.offset += int64(n) + + m, err2 := w.Write(b[:n]) + written += int64(m) + + if err == nil { + err = err2 + } + } + + if err != nil { + if err == io.EOF { + return written, nil // return nil explicitly. + } + + return written, err + } + } +} + +// WriteTo writes the file to the given Writer. +// The return value is the number of bytes written. +// Any error encountered during the write is also returned. +// +// This method is preferred over calling Read multiple times +// to maximise throughput for transferring the entire file, +// especially over high latency links. +func (f *File) WriteTo(w io.Writer) (written int64, err error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.c.disableConcurrentReads { + return f.writeToSequential(w) + } + + // For concurrency, we want to guess how many concurrent workers we should use. + var fileStat *FileStat + if f.c.useFstat { + fileStat, err = f.c.fstat(f.handle) + } else { + fileStat, err = f.c.stat(f.path) + } + if err != nil { + return 0, err + } + + fileSize := fileStat.Size + if fileSize <= uint64(f.c.maxPacket) || !isRegular(fileStat.Mode) { + // only regular files are guaranteed to return (full read) xor (partial read, next error) + return f.writeToSequential(w) + } + + concurrency64 := fileSize/uint64(f.c.maxPacket) + 1 // a bad guess, but better than no guess + if concurrency64 > uint64(f.c.maxConcurrentRequests) || concurrency64 < 1 { + concurrency64 = uint64(f.c.maxConcurrentRequests) + } + // Now that concurrency64 is saturated to an int value, we know this assignment cannot possibly overflow. + concurrency := int(concurrency64) + + chunkSize := f.c.maxPacket + pool := newBufPool(concurrency, chunkSize) + resPool := newResChanPool(concurrency) + + cancel := make(chan struct{}) + var wg sync.WaitGroup + defer func() { + // Once the writing Reduce phase has ended, all the feed work needs to unconditionally stop. + close(cancel) + + // We want to wait until all outstanding goroutines with an `f` or `f.c` reference have completed. + // Just to be sure we don’t orphan any goroutines any hanging references. + wg.Wait() + }() + + type writeWork struct { + b []byte + off int64 + err error + + next chan writeWork + } + writeCh := make(chan writeWork) + + type readWork struct { + id uint32 + res chan result + off int64 + + cur, next chan writeWork + } + readCh := make(chan readWork) + + // Slice: hand out chunks of work on demand, with a `cur` and `next` channel built-in for sequencing. + go func() { + defer close(readCh) + + off := f.offset + + cur := writeCh + for { + id := f.c.nextID() + res := resPool.Get() + + next := make(chan writeWork) + readWork := readWork{ + id: id, + res: res, + off: off, + + cur: cur, + next: next, + } + + f.c.dispatchRequest(res, &sshFxpReadPacket{ + ID: id, + Handle: f.handle, + Offset: uint64(off), + Len: uint32(chunkSize), + }) + + select { + case readCh <- readWork: + case <-cancel: + return + } + + off += int64(chunkSize) + cur = next + } + }() + + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets readWork, and does the Read into a buffer at the given offset. + go func() { + defer wg.Done() + + for readWork := range readCh { + var b []byte + var n int + + s := <-readWork.res + resPool.Put(readWork.res) + + err := s.err + if err == nil { + switch s.typ { + case sshFxpStatus: + err = normaliseError(unmarshalStatus(readWork.id, s.data)) + + case sshFxpData: + sid, data := unmarshalUint32(s.data) + if readWork.id != sid { + err = &unexpectedIDErr{readWork.id, sid} + + } else { + l, data := unmarshalUint32(data) + b = pool.Get()[:l] + n = copy(b, data[:l]) + b = b[:n] + } + + default: + err = unimplementedPacketErr(s.typ) + } + } + + writeWork := writeWork{ + b: b, + off: readWork.off, + err: err, + + next: readWork.next, + } + + select { + case readWork.cur <- writeWork: + case <-cancel: + return + } + + if err != nil { + return + } + } + }() + } + + // Reduce: serialize the results from the reads into sequential writes. + cur := writeCh + for { + packet, ok := <-cur + if !ok { + return written, errors.New("sftp.File.WriteTo: unexpectedly closed channel") + } + + // Because writes are serialized, this will always be the last successfully read byte. + f.offset = packet.off + int64(len(packet.b)) + + if len(packet.b) > 0 { + n, err := w.Write(packet.b) + written += int64(n) + if err != nil { + return written, err + } + } + + if packet.err != nil { + if packet.err == io.EOF { + return written, nil + } + + return written, packet.err + } + + pool.Put(packet.b) + cur = packet.next + } +} + +// Stat returns the FileInfo structure describing file. If there is an +// error. +func (f *File) Stat() (os.FileInfo, error) { + fs, err := f.c.fstat(f.handle) + if err != nil { + return nil, err + } + return fileInfoFromStat(fs, path.Base(f.path)), nil +} + +// Write writes len(b) bytes to the File. It returns the number of bytes +// written and an error, if any. Write returns a non-nil error when n != +// len(b). +// +// To maximise throughput for transferring the entire file (especially +// over high latency links) it is recommended to use ReadFrom rather +// than calling Write multiple times. io.Copy will do this +// automatically. +func (f *File) Write(b []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + + n, err := f.WriteAt(b, f.offset) + f.offset += int64(n) + return n, err +} + +func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) { + typ, data, err := f.c.sendPacket(ch, &sshFxpWritePacket{ + ID: f.c.nextID(), + Handle: f.handle, + Offset: uint64(off), + Length: uint32(len(b)), + Data: b, + }) + if err != nil { + return 0, err + } + + switch typ { + case sshFxpStatus: + id, _ := unmarshalUint32(data) + err := normaliseError(unmarshalStatus(id, data)) + if err != nil { + return 0, err + } + + default: + return 0, unimplementedPacketErr(typ) + } + + return len(b), nil +} + +// writeAtConcurrent implements WriterAt, but works concurrently rather than sequentially. +func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) { + // Split the write into multiple maxPacket sized concurrent writes + // bounded by maxConcurrentRequests. This allows writes with a suitably + // large buffer to transfer data at a much faster rate due to + // overlapping round trip times. + + cancel := make(chan struct{}) + + type work struct { + b []byte + off int64 + } + workCh := make(chan work) + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + var read int + chunkSize := f.c.maxPacket + + for read < len(b) { + wb := b[read:] + if len(wb) > chunkSize { + wb = wb[:chunkSize] + } + + select { + case workCh <- work{wb, off + int64(read)}: + case <-cancel: + return + } + + read += len(wb) + } + }() + + type wErr struct { + off int64 + err error + } + errCh := make(chan wErr) + + concurrency := len(b)/f.c.maxPacket + 1 + if concurrency > f.c.maxConcurrentRequests || concurrency < 1 { + concurrency = f.c.maxConcurrentRequests + } + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and does the Write from each buffer to its respective offset. + go func() { + defer wg.Done() + + ch := make(chan result, 1) // reusable channel per mapper. + + for packet := range workCh { + n, err := f.writeChunkAt(ch, packet.b, packet.off) + if err != nil { + // return the offset as the start + how much we wrote before the error. + errCh <- wErr{packet.off + int64(n), err} + } + } + }() + } + + // Wait for long tail, before closing results. + go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: collect all the results into a relevant return: the earliest offset to return an error. + firstErr := wErr{math.MaxInt64, nil} + for wErr := range errCh { + if wErr.off <= firstErr.off { + firstErr = wErr + } + + select { + case <-cancel: + default: + // stop any more work from being distributed. (Just in case.) + close(cancel) + } + } + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off >= our starting offset. + return int(firstErr.off - off), firstErr.err + } + + return len(b), nil +} + +// WriteAt writes up to len(b) byte to the File at a given offset `off`. It returns +// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics, +// so the file offset is not altered during the write. +func (f *File) WriteAt(b []byte, off int64) (written int, err error) { + if len(b) <= f.c.maxPacket { + // We can do this in one write. + return f.writeChunkAt(nil, b, off) + } + + if f.c.useConcurrentWrites { + return f.writeAtConcurrent(b, off) + } + + ch := make(chan result, 1) // reusable channel + + chunkSize := f.c.maxPacket + + for written < len(b) { + wb := b[written:] + if len(wb) > chunkSize { + wb = wb[:chunkSize] + } + + n, err := f.writeChunkAt(ch, wb, off+int64(written)) + if n > 0 { + written += n + } + + if err != nil { + return written, err + } + } + + return len(b), nil +} + +// ReadFromWithConcurrency implements ReaderFrom, +// but uses the given concurrency to issue multiple requests at the same time. +// +// Giving a concurrency of less than one will default to the Client’s max concurrency. +// +// Otherwise, the given concurrency will be capped by the Client's max concurrency. +func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { + // Split the write into multiple maxPacket sized concurrent writes. + // This allows writes with a suitably large reader + // to transfer data at a much faster rate due to overlapping round trip times. + + cancel := make(chan struct{}) + + type work struct { + b []byte + n int + off int64 + } + workCh := make(chan work) + + type rwErr struct { + off int64 + err error + } + errCh := make(chan rwErr) + + if concurrency > f.c.maxConcurrentRequests || concurrency < 1 { + concurrency = f.c.maxConcurrentRequests + } + + pool := newBufPool(concurrency, f.c.maxPacket) + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + off := f.offset + + for { + b := pool.Get() + + n, err := r.Read(b) + if n > 0 { + read += int64(n) + + select { + case workCh <- work{b, n, off}: + // We need the pool.Put(b) to put the whole slice, not just trunced. + case <-cancel: + return + } + + off += int64(n) + } + + if err != nil { + if err != io.EOF { + errCh <- rwErr{off, err} + } + return + } + } + }() + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and does the Write from each buffer to its respective offset. + go func() { + defer wg.Done() + + ch := make(chan result, 1) // reusable channel per mapper. + + for packet := range workCh { + n, err := f.writeChunkAt(ch, packet.b[:packet.n], packet.off) + if err != nil { + // return the offset as the start + how much we wrote before the error. + errCh <- rwErr{packet.off + int64(n), err} + } + pool.Put(packet.b) + } + }() + } + + // Wait for long tail, before closing results. + go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: Collect all the results into a relevant return: the earliest offset to return an error. + firstErr := rwErr{math.MaxInt64, nil} + for rwErr := range errCh { + if rwErr.off <= firstErr.off { + firstErr = rwErr + } + + select { + case <-cancel: + default: + // stop any more work from being distributed. + close(cancel) + } + } + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off is a valid offset. + // + // firstErr.off will then be the lesser of: + // * the offset of the first error from writing, + // * the last successfully read offset. + // + // This could be less than the last successfully written offset, + // which is the whole reason for the UseConcurrentWrites() ClientOption. + // + // Callers are responsible for truncating any SFTP files to a safe length. + f.offset = firstErr.off + + // ReadFrom is defined to return the read bytes, regardless of any writer errors. + return read, firstErr.err + } + + f.offset += read + return read, nil +} + +// ReadFrom reads data from r until EOF and writes it to the file. The return +// value is the number of bytes read. Any error except io.EOF encountered +// during the read is also returned. +// +// This method is preferred over calling Write multiple times +// to maximise throughput for transferring the entire file, +// especially over high-latency links. +func (f *File) ReadFrom(r io.Reader) (int64, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.c.useConcurrentWrites { + var remain int64 + switch r := r.(type) { + case interface{ Len() int }: + remain = int64(r.Len()) + + case interface{ Size() int64 }: + remain = r.Size() + + case *io.LimitedReader: + remain = r.N + + case interface{ Stat() (os.FileInfo, error) }: + info, err := r.Stat() + if err == nil { + remain = info.Size() + } + } + + if remain < 0 { + // We can strongly assert that we want default max concurrency here. + return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests) + } + + if remain > int64(f.c.maxPacket) { + // Otherwise, only use concurrency, if it would be at least two packets. + + // This is the best reasonable guess we can make. + concurrency64 := remain/int64(f.c.maxPacket) + 1 + + // We need to cap this value to an `int` size value to avoid overflow on 32-bit machines. + // So, we may as well pre-cap it to `f.c.maxConcurrentRequests`. + if concurrency64 > int64(f.c.maxConcurrentRequests) { + concurrency64 = int64(f.c.maxConcurrentRequests) + } + + return f.ReadFromWithConcurrency(r, int(concurrency64)) + } + } + + ch := make(chan result, 1) // reusable channel + + b := make([]byte, f.c.maxPacket) + + var read int64 + for { + n, err := r.Read(b) + if n < 0 { + panic("sftp.File: reader returned negative count from Read") + } + + if n > 0 { + read += int64(n) + + m, err2 := f.writeChunkAt(ch, b[:n], f.offset) + f.offset += int64(m) + + if err == nil { + err = err2 + } + } + + if err != nil { + if err == io.EOF { + return read, nil // return nil explicitly. + } + + return read, err + } + } +} + +// Seek implements io.Seeker by setting the client offset for the next Read or +// Write. It returns the next offset read. Seeking before or after the end of +// the file is undefined. Seeking relative to the end calls Stat. +func (f *File) Seek(offset int64, whence int) (int64, error) { + f.mu.Lock() + defer f.mu.Unlock() + + switch whence { + case io.SeekStart: + case io.SeekCurrent: + offset += f.offset + case io.SeekEnd: + fi, err := f.Stat() + if err != nil { + return f.offset, err + } + offset += fi.Size() + default: + return f.offset, unimplementedSeekWhence(whence) + } + + if offset < 0 { + return f.offset, os.ErrInvalid + } + + f.offset = offset + return f.offset, nil +} + +// Chown changes the uid/gid of the current file. +func (f *File) Chown(uid, gid int) error { + return f.c.Chown(f.path, uid, gid) +} + +// Chmod changes the permissions of the current file. +// +// See Client.Chmod for details. +func (f *File) Chmod(mode os.FileMode) error { + return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) +} + +// Sync requests a flush of the contents of a File to stable storage. +// +// Sync requires the server to support the fsync@openssh.com extension. +func (f *File) Sync() error { + id := f.c.nextID() + typ, data, err := f.c.sendPacket(nil, &sshFxpFsyncPacket{ + ID: id, + Handle: f.handle, + }) + + switch { + case err != nil: + return err + case typ == sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return &unexpectedPacketErr{want: sshFxpStatus, got: typ} + } +} + +// Truncate sets the size of the current file. Although it may be safely assumed +// that if the size is less than its current size it will be truncated to fit, +// the SFTP protocol does not specify what behavior the server should do when setting +// size greater than the current size. +// We send a SSH_FXP_FSETSTAT here since we have a file handle +func (f *File) Truncate(size int64) error { + return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size)) +} + +// normaliseError normalises an error into a more standard form that can be +// checked against stdlib errors like io.EOF or os.ErrNotExist. +func normaliseError(err error) error { + switch err := err.(type) { + case *StatusError: + switch err.Code { + case sshFxEOF: + return io.EOF + case sshFxNoSuchFile: + return os.ErrNotExist + case sshFxPermissionDenied: + return os.ErrPermission + case sshFxOk: + return nil + default: + return err + } + default: + return err + } +} + +// flags converts the flags passed to OpenFile into ssh flags. +// Unsupported flags are ignored. +func flags(f int) uint32 { + var out uint32 + switch f & os.O_WRONLY { + case os.O_WRONLY: + out |= sshFxfWrite + case os.O_RDONLY: + out |= sshFxfRead + } + if f&os.O_RDWR == os.O_RDWR { + out |= sshFxfRead | sshFxfWrite + } + if f&os.O_APPEND == os.O_APPEND { + out |= sshFxfAppend + } + if f&os.O_CREATE == os.O_CREATE { + out |= sshFxfCreat + } + if f&os.O_TRUNC == os.O_TRUNC { + out |= sshFxfTrunc + } + if f&os.O_EXCL == os.O_EXCL { + out |= sshFxfExcl + } + return out +} + +// toChmodPerm converts Go permission bits to POSIX permission bits. +// +// This differs from fromFileMode in that we preserve the POSIX versions of +// setuid, setgid and sticky in m, because we've historically supported those +// bits, and we mask off any non-permission bits. +func toChmodPerm(m os.FileMode) (perm uint32) { + const mask = os.ModePerm | s_ISUID | s_ISGID | s_ISVTX + perm = uint32(m & mask) + + if m&os.ModeSetuid != 0 { + perm |= s_ISUID + } + if m&os.ModeSetgid != 0 { + perm |= s_ISGID + } + if m&os.ModeSticky != 0 { + perm |= s_ISVTX + } + + return perm +} diff --git a/sftp/conn.go b/sftp/conn.go new file mode 100644 index 0000000..7d95142 --- /dev/null +++ b/sftp/conn.go @@ -0,0 +1,189 @@ +package sftp + +import ( + "encoding" + "fmt" + "io" + "sync" +) + +// conn implements a bidirectional channel on which client and server +// connections are multiplexed. +type conn struct { + io.Reader + io.WriteCloser + // this is the same allocator used in packet manager + alloc *allocator + sync.Mutex // used to serialise writes to sendPacket +} + +// the orderID is used in server mode if the allocator is enabled. +// For the client mode just pass 0 +func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { + return recvPacket(c, c.alloc, orderID) +} + +func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { + c.Lock() + defer c.Unlock() + + return sendPacket(c, m) +} + +func (c *conn) Close() error { + c.Lock() + defer c.Unlock() + return c.WriteCloser.Close() +} + +type clientConn struct { + conn + wg sync.WaitGroup + + sync.Mutex // protects inflight + inflight map[uint32]chan<- result // outstanding requests + + closed chan struct{} + err error +} + +// Wait blocks until the conn has shut down, and return the error +// causing the shutdown. It can be called concurrently from multiple +// goroutines. +func (c *clientConn) Wait() error { + <-c.closed + return c.err +} + +// Close closes the SFTP session. +func (c *clientConn) Close() error { + defer c.wg.Wait() + return c.conn.Close() +} + +func (c *clientConn) loop() { + defer c.wg.Done() + err := c.recv() + if err != nil { + c.broadcastErr(err) + } +} + +// recv continuously reads from the server and forwards responses to the +// appropriate channel. +func (c *clientConn) recv() error { + defer c.conn.Close() + + for { + typ, data, err := c.recvPacket(0) + if err != nil { + return err + } + sid, _, err := unmarshalUint32Safe(data) + if err != nil { + return err + } + + ch, ok := c.getChannel(sid) + if !ok { + // This is an unexpected occurrence. Send the error + // back to all listeners so that they terminate + // gracefully. + return fmt.Errorf("sid not found: %d", sid) + } + + ch <- result{typ: typ, data: data} + } +} + +func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool { + c.Lock() + defer c.Unlock() + + select { + case <-c.closed: + // already closed with broadcastErr, return error on chan. + ch <- result{err: ErrSSHFxConnectionLost} + return false + default: + } + + c.inflight[sid] = ch + return true +} + +func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) { + c.Lock() + defer c.Unlock() + + ch, ok := c.inflight[sid] + delete(c.inflight, sid) + + return ch, ok +} + +// result captures the result of receiving the a packet from the server +type result struct { + typ byte + data []byte + err error +} + +type idmarshaler interface { + id() uint32 + encoding.BinaryMarshaler +} + +func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) { + if cap(ch) < 1 { + ch = make(chan result, 1) + } + + c.dispatchRequest(ch, p) + s := <-ch + return s.typ, s.data, s.err +} + +// dispatchRequest should ideally only be called by race-detection tests outside of this file, +// where you have to ensure two packets are in flight sequentially after each other. +func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) { + sid := p.id() + + if !c.putChannel(ch, sid) { + // already closed. + return + } + + if err := c.conn.sendPacket(p); err != nil { + if ch, ok := c.getChannel(sid); ok { + ch <- result{err: err} + } + } +} + +// broadcastErr sends an error to all goroutines waiting for a response. +func (c *clientConn) broadcastErr(err error) { + c.Lock() + defer c.Unlock() + + bcastRes := result{err: ErrSSHFxConnectionLost} + for sid, ch := range c.inflight { + ch <- bcastRes + + // Replace the chan in inflight, + // we have hijacked this chan, + // and this guarantees always-only-once sending. + c.inflight[sid] = make(chan<- result, 1) + } + + c.err = err + close(c.closed) +} + +type serverConn struct { + conn +} + +func (s *serverConn) sendError(id uint32, err error) error { + return s.sendPacket(statusFromError(id, err)) +} diff --git a/sftp/debug.go b/sftp/debug.go new file mode 100644 index 0000000..3e264ab --- /dev/null +++ b/sftp/debug.go @@ -0,0 +1,9 @@ +// +build debug + +package sftp + +import "log" + +func debug(fmt string, args ...interface{}) { + log.Printf(fmt, args...) +} diff --git a/sftp/ls_formatting.go b/sftp/ls_formatting.go new file mode 100644 index 0000000..9c59070 --- /dev/null +++ b/sftp/ls_formatting.go @@ -0,0 +1,81 @@ +package sftp + +import ( + "errors" + "fmt" + "os" + "os/user" + "strconv" + "time" + + sshfx "git.deuxfleurs.fr/Deuxfleurs/bagage/internal/encoding/ssh/filexfer" +) + +func lsFormatID(id uint32) string { + return strconv.FormatUint(uint64(id), 10) +} + +type osIDLookup struct{} + +func (osIDLookup) Filelist(*Request) (ListerAt, error) { + return nil, errors.New("unimplemented stub") +} + +func (osIDLookup) LookupUserName(uid string) string { + u, err := user.LookupId(uid) + if err != nil { + return uid + } + + return u.Username +} + +func (osIDLookup) LookupGroupName(gid string) string { + g, err := user.LookupGroupId(gid) + if err != nil { + return gid + } + + return g.Name +} + +// runLs formats the FileInfo as per `ls -l` style, which is in the 'longname' field of a SSH_FXP_NAME entry. +// This is a fairly simple implementation, just enough to look close to openssh in simple cases. +func runLs(idLookup NameLookupFileLister, dirent os.FileInfo) string { + // example from openssh sftp server: + // crw-rw-rw- 1 root wheel 0 Jul 31 20:52 ttyvd + // format: + // {directory / char device / etc}{rwxrwxrwx} {number of links} owner group size month day [time (this year) | year (otherwise)] name + + symPerms := sshfx.FileMode(fromFileMode(dirent.Mode())).String() + + var numLinks uint64 = 1 + uid, gid := "0", "0" + + switch sys := dirent.Sys().(type) { + case *sshfx.Attributes: + uid = lsFormatID(sys.UID) + gid = lsFormatID(sys.GID) + case *FileStat: + uid = lsFormatID(sys.UID) + gid = lsFormatID(sys.GID) + default: + numLinks, uid, gid = lsLinksUIDGID(dirent) + } + + if idLookup != nil { + uid, gid = idLookup.LookupUserName(uid), idLookup.LookupGroupName(gid) + } + + mtime := dirent.ModTime() + date := mtime.Format("Jan 2") + + var yearOrTime string + if mtime.Before(time.Now().AddDate(0, -6, 0)) { + yearOrTime = mtime.Format("2006") + } else { + yearOrTime = mtime.Format("15:04") + } + + return fmt.Sprintf("%s %4d %-8s %-8s %8d %s %5s %s", symPerms, numLinks, uid, gid, dirent.Size(), date, yearOrTime, dirent.Name()) +} diff --git a/sftp/ls_plan9.go b/sftp/ls_plan9.go new file mode 100644 index 0000000..a16a3ea --- /dev/null +++ b/sftp/ls_plan9.go @@ -0,0 +1,21 @@ +// +build plan9 + +package sftp + +import ( + "os" + "syscall" +) + +func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) { + numLinks = 1 + uid, gid = "0", "0" + + switch sys := fi.Sys().(type) { + case *syscall.Dir: + uid = sys.Uid + gid = sys.Gid + } + + return numLinks, uid, gid +} diff --git a/sftp/ls_stub.go b/sftp/ls_stub.go new file mode 100644 index 0000000..6dec393 --- /dev/null +++ b/sftp/ls_stub.go @@ -0,0 +1,11 @@ +// +build windows android + +package sftp + +import ( + "os" +) + +func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) { + return 1, "0", "0" +} diff --git a/sftp/ls_unix.go b/sftp/ls_unix.go new file mode 100644 index 0000000..59ccffd --- /dev/null +++ b/sftp/ls_unix.go @@ -0,0 +1,23 @@ +// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js + +package sftp + +import ( + "os" + "syscall" +) + +func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) { + numLinks = 1 + uid, gid = "0", "0" + + switch sys := fi.Sys().(type) { + case *syscall.Stat_t: + numLinks = uint64(sys.Nlink) + uid = lsFormatID(sys.Uid) + gid = lsFormatID(sys.Gid) + default: + } + + return numLinks, uid, gid +} diff --git a/sftp/packet.go b/sftp/packet.go new file mode 100644 index 0000000..50ca069 --- /dev/null +++ b/sftp/packet.go @@ -0,0 +1,1276 @@ +package sftp + +import ( + "bytes" + "encoding" + "encoding/binary" + "errors" + "fmt" + "io" + "os" + "reflect" +) + +var ( + errLongPacket = errors.New("packet too long") + errShortPacket = errors.New("packet too short") + errUnknownExtendedPacket = errors.New("unknown extended packet") +) + +const ( + maxMsgLength = 256 * 1024 + debugDumpTxPacket = false + debugDumpRxPacket = false + debugDumpTxPacketBytes = false + debugDumpRxPacketBytes = false +) + +func marshalUint32(b []byte, v uint32) []byte { + return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +func marshalUint64(b []byte, v uint64) []byte { + return marshalUint32(marshalUint32(b, uint32(v>>32)), uint32(v)) +} + +func marshalString(b []byte, v string) []byte { + return append(marshalUint32(b, uint32(len(v))), v...) +} + +func marshalFileInfo(b []byte, fi os.FileInfo) []byte { + // attributes variable struct, and also variable per protocol version + // spec version 3 attributes: + // uint32 flags + // uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE + // uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS + // uint32 atime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED + // string extended_type + // string extended_data + // ... more extended data (extended_type - extended_data pairs), + // so that number of pairs equals extended_count + + flags, fileStat := fileStatFromInfo(fi) + + b = marshalUint32(b, flags) + if flags&sshFileXferAttrSize != 0 { + b = marshalUint64(b, fileStat.Size) + } + if flags&sshFileXferAttrUIDGID != 0 { + b = marshalUint32(b, fileStat.UID) + b = marshalUint32(b, fileStat.GID) + } + if flags&sshFileXferAttrPermissions != 0 { + b = marshalUint32(b, fileStat.Mode) + } + if flags&sshFileXferAttrACmodTime != 0 { + b = marshalUint32(b, fileStat.Atime) + b = marshalUint32(b, fileStat.Mtime) + } + + return b +} + +func marshalStatus(b []byte, err StatusError) []byte { + b = marshalUint32(b, err.Code) + b = marshalString(b, err.msg) + b = marshalString(b, err.lang) + return b +} + +func marshal(b []byte, v interface{}) []byte { + if v == nil { + return b + } + switch v := v.(type) { + case uint8: + return append(b, v) + case uint32: + return marshalUint32(b, v) + case uint64: + return marshalUint64(b, v) + case string: + return marshalString(b, v) + case os.FileInfo: + return marshalFileInfo(b, v) + default: + switch d := reflect.ValueOf(v); d.Kind() { + case reflect.Struct: + for i, n := 0, d.NumField(); i < n; i++ { + b = marshal(b, d.Field(i).Interface()) + } + return b + case reflect.Slice: + for i, n := 0, d.Len(); i < n; i++ { + b = marshal(b, d.Index(i).Interface()) + } + return b + default: + panic(fmt.Sprintf("marshal(%#v): cannot handle type %T", v, v)) + } + } +} + +func unmarshalUint32(b []byte) (uint32, []byte) { + v := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 + return v, b[4:] +} + +func unmarshalUint32Safe(b []byte) (uint32, []byte, error) { + var v uint32 + if len(b) < 4 { + return 0, nil, errShortPacket + } + v, b = unmarshalUint32(b) + return v, b, nil +} + +func unmarshalUint64(b []byte) (uint64, []byte) { + h, b := unmarshalUint32(b) + l, b := unmarshalUint32(b) + return uint64(h)<<32 | uint64(l), b +} + +func unmarshalUint64Safe(b []byte) (uint64, []byte, error) { + var v uint64 + if len(b) < 8 { + return 0, nil, errShortPacket + } + v, b = unmarshalUint64(b) + return v, b, nil +} + +func unmarshalString(b []byte) (string, []byte) { + n, b := unmarshalUint32(b) + return string(b[:n]), b[n:] +} + +func unmarshalStringSafe(b []byte) (string, []byte, error) { + n, b, err := unmarshalUint32Safe(b) + if err != nil { + return "", nil, err + } + if int64(n) > int64(len(b)) { + return "", nil, errShortPacket + } + return string(b[:n]), b[n:], nil +} + +func unmarshalAttrs(b []byte) (*FileStat, []byte) { + flags, b := unmarshalUint32(b) + return unmarshalFileStat(flags, b) +} + +func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { + var fs FileStat + if flags&sshFileXferAttrSize == sshFileXferAttrSize { + fs.Size, b, _ = unmarshalUint64Safe(b) + } + if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { + fs.UID, b, _ = unmarshalUint32Safe(b) + } + if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { + fs.GID, b, _ = unmarshalUint32Safe(b) + } + if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { + fs.Mode, b, _ = unmarshalUint32Safe(b) + } + if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime { + fs.Atime, b, _ = unmarshalUint32Safe(b) + fs.Mtime, b, _ = unmarshalUint32Safe(b) + } + if flags&sshFileXferAttrExtended == sshFileXferAttrExtended { + var count uint32 + count, b, _ = unmarshalUint32Safe(b) + ext := make([]StatExtended, count) + for i := uint32(0); i < count; i++ { + var typ string + var data string + typ, b, _ = unmarshalStringSafe(b) + data, b, _ = unmarshalStringSafe(b) + ext[i] = StatExtended{ + ExtType: typ, + ExtData: data, + } + } + fs.Extended = ext + } + return &fs, b +} + +func unmarshalStatus(id uint32, data []byte) error { + sid, data := unmarshalUint32(data) + if sid != id { + return &unexpectedIDErr{id, sid} + } + code, data := unmarshalUint32(data) + msg, data, _ := unmarshalStringSafe(data) + lang, _, _ := unmarshalStringSafe(data) + return &StatusError{ + Code: code, + msg: msg, + lang: lang, + } +} + +type packetMarshaler interface { + marshalPacket() (header, payload []byte, err error) +} + +func marshalPacket(m encoding.BinaryMarshaler) (header, payload []byte, err error) { + if m, ok := m.(packetMarshaler); ok { + return m.marshalPacket() + } + + header, err = m.MarshalBinary() + return +} + +// sendPacket marshals p according to RFC 4234. +func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { + header, payload, err := marshalPacket(m) + if err != nil { + return fmt.Errorf("binary marshaller failed: %w", err) + } + + length := len(header) + len(payload) - 4 // subtract the uint32(length) from the start + if debugDumpTxPacketBytes { + debug("send packet: %s %d bytes %x%x", fxp(header[4]), length, header[5:], payload) + } else if debugDumpTxPacket { + debug("send packet: %s %d bytes", fxp(header[4]), length) + } + + binary.BigEndian.PutUint32(header[:4], uint32(length)) + + if _, err := w.Write(header); err != nil { + return fmt.Errorf("failed to send packet: %w", err) + } + + if len(payload) > 0 { + if _, err := w.Write(payload); err != nil { + return fmt.Errorf("failed to send packet payload: %w", err) + } + } + + return nil +} + +func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) { + var b []byte + if alloc != nil { + b = alloc.GetPage(orderID) + } else { + b = make([]byte, 4) + } + if _, err := io.ReadFull(r, b[:4]); err != nil { + return 0, nil, err + } + length, _ := unmarshalUint32(b) + if length > maxMsgLength { + debug("recv packet %d bytes too long", length) + return 0, nil, errLongPacket + } + if length == 0 { + debug("recv packet of 0 bytes too short") + return 0, nil, errShortPacket + } + if alloc == nil { + b = make([]byte, length) + } + if _, err := io.ReadFull(r, b[:length]); err != nil { + debug("recv packet %d bytes: err %v", length, err) + return 0, nil, err + } + if debugDumpRxPacketBytes { + debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length]) + } else if debugDumpRxPacket { + debug("recv packet: %s %d bytes", fxp(b[0]), length) + } + return b[0], b[1:length], nil +} + +type extensionPair struct { + Name string + Data string +} + +func unmarshalExtensionPair(b []byte) (extensionPair, []byte, error) { + var ep extensionPair + var err error + ep.Name, b, err = unmarshalStringSafe(b) + if err != nil { + return ep, b, err + } + ep.Data, b, err = unmarshalStringSafe(b) + return ep, b, err +} + +// Here starts the definition of packets along with their MarshalBinary +// implementations. +// Manually writing the marshalling logic wins us a lot of time and +// allocation. + +type sshFxInitPacket struct { + Version uint32 + Extensions []extensionPair +} + +func (p *sshFxInitPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version) + for _, e := range p.Extensions { + l += 4 + len(e.Name) + 4 + len(e.Data) + } + + b := make([]byte, 4, l) + b = append(b, sshFxpInit) + b = marshalUint32(b, p.Version) + + for _, e := range p.Extensions { + b = marshalString(b, e.Name) + b = marshalString(b, e.Data) + } + + return b, nil +} + +func (p *sshFxInitPacket) UnmarshalBinary(b []byte) error { + var err error + if p.Version, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + for len(b) > 0 { + var ep extensionPair + ep, b, err = unmarshalExtensionPair(b) + if err != nil { + return err + } + p.Extensions = append(p.Extensions, ep) + } + return nil +} + +type sshFxVersionPacket struct { + Version uint32 + Extensions []sshExtensionPair +} + +type sshExtensionPair struct { + Name, Data string +} + +func (p *sshFxVersionPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version) + for _, e := range p.Extensions { + l += 4 + len(e.Name) + 4 + len(e.Data) + } + + b := make([]byte, 4, l) + b = append(b, sshFxpVersion) + b = marshalUint32(b, p.Version) + + for _, e := range p.Extensions { + b = marshalString(b, e.Name) + b = marshalString(b, e.Data) + } + + return b, nil +} + +func marshalIDStringPacket(packetType byte, id uint32, str string) ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(str) + + b := make([]byte, 4, l) + b = append(b, packetType) + b = marshalUint32(b, id) + b = marshalString(b, str) + + return b, nil +} + +func unmarshalIDString(b []byte, id *uint32, str *string) error { + var err error + *id, b, err = unmarshalUint32Safe(b) + if err != nil { + return err + } + *str, _, err = unmarshalStringSafe(b) + return err +} + +type sshFxpReaddirPacket struct { + ID uint32 + Handle string +} + +func (p *sshFxpReaddirPacket) id() uint32 { return p.ID } + +func (p *sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpReaddir, p.ID, p.Handle) +} + +func (p *sshFxpReaddirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) +} + +type sshFxpOpendirPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpOpendirPacket) id() uint32 { return p.ID } + +func (p *sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpOpendir, p.ID, p.Path) +} + +func (p *sshFxpOpendirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpLstatPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpLstatPacket) id() uint32 { return p.ID } + +func (p *sshFxpLstatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpLstat, p.ID, p.Path) +} + +func (p *sshFxpLstatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpStatPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpStatPacket) id() uint32 { return p.ID } + +func (p *sshFxpStatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpStat, p.ID, p.Path) +} + +func (p *sshFxpStatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpFstatPacket struct { + ID uint32 + Handle string +} + +func (p *sshFxpFstatPacket) id() uint32 { return p.ID } + +func (p *sshFxpFstatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpFstat, p.ID, p.Handle) +} + +func (p *sshFxpFstatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) +} + +type sshFxpClosePacket struct { + ID uint32 + Handle string +} + +func (p *sshFxpClosePacket) id() uint32 { return p.ID } + +func (p *sshFxpClosePacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpClose, p.ID, p.Handle) +} + +func (p *sshFxpClosePacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) +} + +type sshFxpRemovePacket struct { + ID uint32 + Filename string +} + +func (p *sshFxpRemovePacket) id() uint32 { return p.ID } + +func (p *sshFxpRemovePacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRemove, p.ID, p.Filename) +} + +func (p *sshFxpRemovePacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Filename) +} + +type sshFxpRmdirPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpRmdirPacket) id() uint32 { return p.ID } + +func (p *sshFxpRmdirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRmdir, p.ID, p.Path) +} + +func (p *sshFxpRmdirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpSymlinkPacket struct { + ID uint32 + Targetpath string + Linkpath string +} + +func (p *sshFxpSymlinkPacket) id() uint32 { return p.ID } + +func (p *sshFxpSymlinkPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Targetpath) + + 4 + len(p.Linkpath) + + b := make([]byte, 4, l) + b = append(b, sshFxpSymlink) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Targetpath) + b = marshalString(b, p.Linkpath) + + return b, nil +} + +func (p *sshFxpSymlinkPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Targetpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Linkpath, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +type sshFxpHardlinkPacket struct { + ID uint32 + Oldpath string + Newpath string +} + +func (p *sshFxpHardlinkPacket) id() uint32 { return p.ID } + +func (p *sshFxpHardlinkPacket) MarshalBinary() ([]byte, error) { + const ext = "hardlink@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Oldpath) + + 4 + len(p.Newpath) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, ext) + b = marshalString(b, p.Oldpath) + b = marshalString(b, p.Newpath) + + return b, nil +} + +type sshFxpReadlinkPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpReadlinkPacket) id() uint32 { return p.ID } + +func (p *sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpReadlink, p.ID, p.Path) +} + +func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpRealpathPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpRealpathPacket) id() uint32 { return p.ID } + +func (p *sshFxpRealpathPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRealpath, p.ID, p.Path) +} + +func (p *sshFxpRealpathPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpNameAttr struct { + Name string + LongName string + Attrs []interface{} +} + +func (p *sshFxpNameAttr) MarshalBinary() ([]byte, error) { + var b []byte + b = marshalString(b, p.Name) + b = marshalString(b, p.LongName) + for _, attr := range p.Attrs { + b = marshal(b, attr) + } + return b, nil +} + +type sshFxpNamePacket struct { + ID uint32 + NameAttrs []*sshFxpNameAttr +} + +func (p *sshFxpNamePacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + b := make([]byte, 4, l) + b = append(b, sshFxpName) + b = marshalUint32(b, p.ID) + b = marshalUint32(b, uint32(len(p.NameAttrs))) + + var payload []byte + for _, na := range p.NameAttrs { + ab, err := na.MarshalBinary() + if err != nil { + return nil, nil, err + } + + payload = append(payload, ab...) + } + + return b, payload, nil +} + +func (p *sshFxpNamePacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +type sshFxpOpenPacket struct { + ID uint32 + Path string + Pflags uint32 + Flags uint32 // ignored +} + +func (p *sshFxpOpenPacket) id() uint32 { return p.ID } + +func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Path) + + 4 + 4 + + b := make([]byte, 4, l) + b = append(b, sshFxpOpen) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Pflags) + b = marshalUint32(b, p.Flags) + + return b, nil +} + +func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + +type sshFxpReadPacket struct { + ID uint32 + Len uint32 + Offset uint64 + Handle string +} + +func (p *sshFxpReadPacket) id() uint32 { return p.ID } + +func (p *sshFxpReadPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Handle) + + 8 + 4 // uint64 + uint32 + + b := make([]byte, 4, l) + b = append(b, sshFxpRead) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + b = marshalUint64(b, p.Offset) + b = marshalUint32(b, p.Len) + + return b, nil +} + +func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil { + return err + } else if p.Len, _, err = unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + +// We need allocate bigger slices with extra capacity to avoid a re-allocation in sshFxpDataPacket.MarshalBinary +// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length) +const dataHeaderLen = 4 + 1 + 4 + 4 + +func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { + dataLen := p.Len + if dataLen > maxTxPacket { + dataLen = maxTxPacket + } + + if alloc != nil { + // GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in + // sshFxpDataPacket.MarshalBinary + return alloc.GetPage(orderID)[:dataLen] + } + + // allocate with extra space for the header + return make([]byte, dataLen, dataLen+dataHeaderLen) +} + +type sshFxpRenamePacket struct { + ID uint32 + Oldpath string + Newpath string +} + +func (p *sshFxpRenamePacket) id() uint32 { return p.ID } + +func (p *sshFxpRenamePacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Oldpath) + + 4 + len(p.Newpath) + + b := make([]byte, 4, l) + b = append(b, sshFxpRename) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Oldpath) + b = marshalString(b, p.Newpath) + + return b, nil +} + +func (p *sshFxpRenamePacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +type sshFxpPosixRenamePacket struct { + ID uint32 + Oldpath string + Newpath string +} + +func (p *sshFxpPosixRenamePacket) id() uint32 { return p.ID } + +func (p *sshFxpPosixRenamePacket) MarshalBinary() ([]byte, error) { + const ext = "posix-rename@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Oldpath) + + 4 + len(p.Newpath) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, ext) + b = marshalString(b, p.Oldpath) + b = marshalString(b, p.Newpath) + + return b, nil +} + +type sshFxpWritePacket struct { + ID uint32 + Length uint32 + Offset uint64 + Handle string + Data []byte +} + +func (p *sshFxpWritePacket) id() uint32 { return p.ID } + +func (p *sshFxpWritePacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Handle) + + 8 + // uint64 + 4 + + b := make([]byte, 4, l) + b = append(b, sshFxpWrite) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + b = marshalUint64(b, p.Offset) + b = marshalUint32(b, p.Length) + + return b, p.Data, nil +} + +func (p *sshFxpWritePacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil { + return err + } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if uint32(len(b)) < p.Length { + return errShortPacket + } + + p.Data = b[:p.Length] + return nil +} + +type sshFxpMkdirPacket struct { + ID uint32 + Flags uint32 // ignored + Path string +} + +func (p *sshFxpMkdirPacket) id() uint32 { return p.ID } + +func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Path) + + 4 // uint32 + + b := make([]byte, 4, l) + b = append(b, sshFxpMkdir) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Flags) + + return b, nil +} + +func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + +type sshFxpSetstatPacket struct { + ID uint32 + Flags uint32 + Path string + Attrs interface{} +} + +type sshFxpFsetstatPacket struct { + ID uint32 + Flags uint32 + Handle string + Attrs interface{} +} + +func (p *sshFxpSetstatPacket) id() uint32 { return p.ID } +func (p *sshFxpFsetstatPacket) id() uint32 { return p.ID } + +func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Path) + + 4 // uint32 + + b := make([]byte, 4, l) + b = append(b, sshFxpSetstat) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Path) + b = marshalUint32(b, p.Flags) + + payload := marshal(nil, p.Attrs) + + return b, payload, nil +} + +func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Handle) + + 4 // uint32 + + b := make([]byte, 4, l) + b = append(b, sshFxpFsetstat) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + b = marshalUint32(b, p.Flags) + + payload := marshal(nil, p.Attrs) + + return b, payload, nil +} + +func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + +func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + +type sshFxpHandlePacket struct { + ID uint32 + Handle string +} + +func (p *sshFxpHandlePacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Handle) + + b := make([]byte, 4, l) + b = append(b, sshFxpHandle) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + + return b, nil +} + +type sshFxpStatusPacket struct { + ID uint32 + StatusError +} + +func (p *sshFxpStatusPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + 4 + len(p.StatusError.msg) + + 4 + len(p.StatusError.lang) + + b := make([]byte, 4, l) + b = append(b, sshFxpStatus) + b = marshalUint32(b, p.ID) + b = marshalStatus(b, p.StatusError) + + return b, nil +} + +type sshFxpDataPacket struct { + ID uint32 + Length uint32 + Data []byte +} + +func (p *sshFxpDataPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + b := make([]byte, 4, l) + b = append(b, sshFxpData) + b = marshalUint32(b, p.ID) + b = marshalUint32(b, p.Length) + + return b, p.Data, nil +} + +// MarshalBinary encodes the receiver into a binary form and returns the result. +// To avoid a new allocation the Data slice must have a capacity >= Length + 9 +// +// This is hand-coded rather than just append(header, payload...), +// in order to try and reuse the r.Data backing store in the packet. +func (p *sshFxpDataPacket) MarshalBinary() ([]byte, error) { + b := append(p.Data, make([]byte, dataHeaderLen)...) + copy(b[dataHeaderLen:], p.Data[:p.Length]) + // b[0:4] will be overwritten with the length in sendPacket + b[4] = sshFxpData + binary.BigEndian.PutUint32(b[5:9], p.ID) + binary.BigEndian.PutUint32(b[9:13], p.Length) + return b, nil +} + +func (p *sshFxpDataPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if uint32(len(b)) < p.Length { + return errShortPacket + } + + p.Data = b[:p.Length] + return nil +} + +type sshFxpStatvfsPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpStatvfsPacket) id() uint32 { return p.ID } + +func (p *sshFxpStatvfsPacket) MarshalBinary() ([]byte, error) { + const ext = "statvfs@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Path) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, ext) + b = marshalString(b, p.Path) + + return b, nil +} + +// A StatVFS contains statistics about a filesystem. +type StatVFS struct { + ID uint32 + Bsize uint64 /* file system block size */ + Frsize uint64 /* fundamental fs block size */ + Blocks uint64 /* number of blocks (unit f_frsize) */ + Bfree uint64 /* free blocks in file system */ + Bavail uint64 /* free blocks for non-root */ + Files uint64 /* total file inodes */ + Ffree uint64 /* free file inodes */ + Favail uint64 /* free file inodes for to non-root */ + Fsid uint64 /* file system id */ + Flag uint64 /* bit mask of f_flag values */ + Namemax uint64 /* maximum filename length */ +} + +// TotalSpace calculates the amount of total space in a filesystem. +func (p *StatVFS) TotalSpace() uint64 { + return p.Frsize * p.Blocks +} + +// FreeSpace calculates the amount of free space in a filesystem. +func (p *StatVFS) FreeSpace() uint64 { + return p.Frsize * p.Bfree +} + +// marshalPacket converts to ssh_FXP_EXTENDED_REPLY packet binary format +func (p *StatVFS) marshalPacket() ([]byte, []byte, error) { + header := []byte{0, 0, 0, 0, sshFxpExtendedReply} + + var buf bytes.Buffer + err := binary.Write(&buf, binary.BigEndian, p) + + return header, buf.Bytes(), err +} + +// MarshalBinary encodes the StatVFS as an SSH_FXP_EXTENDED_REPLY packet. +func (p *StatVFS) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +type sshFxpFsyncPacket struct { + ID uint32 + Handle string +} + +func (p *sshFxpFsyncPacket) id() uint32 { return p.ID } + +func (p *sshFxpFsyncPacket) MarshalBinary() ([]byte, error) { + const ext = "fsync@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Handle) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, ext) + b = marshalString(b, p.Handle) + + return b, nil +} + +type sshFxpExtendedPacket struct { + ID uint32 + ExtendedRequest string + SpecificPacket interface { + serverRespondablePacket + readonly() bool + } +} + +func (p *sshFxpExtendedPacket) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacket) readonly() bool { + if p.SpecificPacket == nil { + return true + } + return p.SpecificPacket.readonly() +} + +func (p *sshFxpExtendedPacket) respond(svr *Server) responsePacket { + if p.SpecificPacket == nil { + return statusFromError(p.ID, nil) + } + return p.SpecificPacket.respond(svr) +} + +func (p *sshFxpExtendedPacket) UnmarshalBinary(b []byte) error { + var err error + bOrig := b + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, _, err = unmarshalStringSafe(b); err != nil { + return err + } + + // specific unmarshalling + switch p.ExtendedRequest { + case "statvfs@openssh.com": + p.SpecificPacket = &sshFxpExtendedPacketStatVFS{} + case "posix-rename@openssh.com": + p.SpecificPacket = &sshFxpExtendedPacketPosixRename{} + case "hardlink@openssh.com": + p.SpecificPacket = &sshFxpExtendedPacketHardlink{} + default: + return fmt.Errorf("packet type %v: %w", p.SpecificPacket, errUnknownExtendedPacket) + } + + return p.SpecificPacket.UnmarshalBinary(bOrig) +} + +type sshFxpExtendedPacketStatVFS struct { + ID uint32 + ExtendedRequest string + Path string +} + +func (p *sshFxpExtendedPacketStatVFS) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketStatVFS) readonly() bool { return true } +func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Path, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +type sshFxpExtendedPacketPosixRename struct { + ID uint32 + ExtendedRequest string + Oldpath string + Newpath string +} + +func (p *sshFxpExtendedPacketPosixRename) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketPosixRename) readonly() bool { return false } +func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket { + err := os.Rename(p.Oldpath, p.Newpath) + return statusFromError(p.ID, err) +} + +type sshFxpExtendedPacketHardlink struct { + ID uint32 + ExtendedRequest string + Oldpath string + Newpath string +} + +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +func (p *sshFxpExtendedPacketHardlink) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketHardlink) readonly() bool { return true } +func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket { + err := os.Link(p.Oldpath, p.Newpath) + return statusFromError(p.ID, err) +} diff --git a/sftp/packet_manager.go b/sftp/packet_manager.go new file mode 100644 index 0000000..5aeb72b --- /dev/null +++ b/sftp/packet_manager.go @@ -0,0 +1,221 @@ +package sftp + +/* + Imported from: https://github.com/pkg/sftp + */ + +import ( + "encoding" + "sort" + "sync" +) + +// The goal of the packetManager is to keep the outgoing packets in the same +// order as the incoming as is requires by section 7 of the RFC. + +type packetManager struct { + requests chan orderedPacket + responses chan orderedPacket + fini chan struct{} + incoming orderedPackets + outgoing orderedPackets + sender packetSender // connection object + working *sync.WaitGroup + packetCount uint32 + // it is not nil if the allocator is enabled + alloc *allocator +} + +type packetSender interface { + sendPacket(encoding.BinaryMarshaler) error +} + +func newPktMgr(sender packetSender) *packetManager { + s := &packetManager{ + requests: make(chan orderedPacket, SftpServerWorkerCount), + responses: make(chan orderedPacket, SftpServerWorkerCount), + fini: make(chan struct{}), + incoming: make([]orderedPacket, 0, SftpServerWorkerCount), + outgoing: make([]orderedPacket, 0, SftpServerWorkerCount), + sender: sender, + working: &sync.WaitGroup{}, + } + go s.controller() + return s +} + +//// packet ordering +func (s *packetManager) newOrderID() uint32 { + s.packetCount++ + return s.packetCount +} + +// returns the next orderID without incrementing it. +// This is used before receiving a new packet, with the allocator enabled, to associate +// the slice allocated for the received packet with the orderID that will be used to mark +// the allocated slices for reuse once the request is served +func (s *packetManager) getNextOrderID() uint32 { + return s.packetCount + 1 +} + +type orderedRequest struct { + requestPacket + orderid uint32 +} + +func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest { + return orderedRequest{requestPacket: p, orderid: s.newOrderID()} +} +func (p orderedRequest) orderID() uint32 { return p.orderid } +func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid } + +type orderedResponse struct { + responsePacket + orderid uint32 +} + +func (s *packetManager) newOrderedResponse(p responsePacket, id uint32, +) orderedResponse { + return orderedResponse{responsePacket: p, orderid: id} +} +func (p orderedResponse) orderID() uint32 { return p.orderid } +func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid } + +type orderedPacket interface { + id() uint32 + orderID() uint32 +} +type orderedPackets []orderedPacket + +func (o orderedPackets) Sort() { + sort.Slice(o, func(i, j int) bool { + return o[i].orderID() < o[j].orderID() + }) +} + +//// packet registry +// register incoming packets to be handled +func (s *packetManager) incomingPacket(pkt orderedRequest) { + s.working.Add(1) + s.requests <- pkt +} + +// register outgoing packets as being ready +func (s *packetManager) readyPacket(pkt orderedResponse) { + s.responses <- pkt + s.working.Done() +} + +// shut down packetManager controller +func (s *packetManager) close() { + // pause until current packets are processed + s.working.Wait() + close(s.fini) +} + +// Passed a worker function, returns a channel for incoming packets. +// Keep process packet responses in the order they are received while +// maximizing throughput of file transfers. +func (s *packetManager) workerChan(runWorker func(chan orderedRequest), +) chan orderedRequest { + // multiple workers for faster read/writes + rwChan := make(chan orderedRequest, SftpServerWorkerCount) + for i := 0; i < SftpServerWorkerCount; i++ { + runWorker(rwChan) + } + + // single worker to enforce sequential processing of everything else + cmdChan := make(chan orderedRequest) + runWorker(cmdChan) + + pktChan := make(chan orderedRequest, SftpServerWorkerCount) + go func() { + for pkt := range pktChan { + switch pkt.requestPacket.(type) { + case *sshFxpReadPacket, *sshFxpWritePacket: + s.incomingPacket(pkt) + rwChan <- pkt + continue + case *sshFxpClosePacket: + // wait for reads/writes to finish when file is closed + // incomingPacket() call must occur after this + s.working.Wait() + } + s.incomingPacket(pkt) + // all non-RW use sequential cmdChan + cmdChan <- pkt + } + close(rwChan) + close(cmdChan) + s.close() + }() + + return pktChan +} + +// process packets +func (s *packetManager) controller() { + for { + select { + case pkt := <-s.requests: + debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID()) + s.incoming = append(s.incoming, pkt) + s.incoming.Sort() + case pkt := <-s.responses: + debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID()) + s.outgoing = append(s.outgoing, pkt) + s.outgoing.Sort() + case <-s.fini: + return + } + s.maybeSendPackets() + } +} + +// send as many packets as are ready +func (s *packetManager) maybeSendPackets() { + for { + if len(s.outgoing) == 0 || len(s.incoming) == 0 { + debug("break! -- outgoing: %v; incoming: %v", + len(s.outgoing), len(s.incoming)) + break + } + out := s.outgoing[0] + in := s.incoming[0] + // debug("incoming: %v", ids(s.incoming)) + // debug("outgoing: %v", ids(s.outgoing)) + if in.orderID() == out.orderID() { + debug("Sending packet: %v", out.id()) + s.sender.sendPacket(out.(encoding.BinaryMarshaler)) + if s.alloc != nil { + // mark for reuse the slices allocated for this request + s.alloc.ReleasePages(in.orderID()) + } + // pop off heads + copy(s.incoming, s.incoming[1:]) // shift left + s.incoming[len(s.incoming)-1] = nil // clear last + s.incoming = s.incoming[:len(s.incoming)-1] // remove last + copy(s.outgoing, s.outgoing[1:]) // shift left + s.outgoing[len(s.outgoing)-1] = nil // clear last + s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last + } else { + break + } + } +} + +// func oids(o []orderedPacket) []uint32 { +// res := make([]uint32, 0, len(o)) +// for _, v := range o { +// res = append(res, v.orderId()) +// } +// return res +// } +// func ids(o []orderedPacket) []uint32 { +// res := make([]uint32, 0, len(o)) +// for _, v := range o { +// res = append(res, v.id()) +// } +// return res +// } + diff --git a/sftp/packet_typing.go b/sftp/packet_typing.go new file mode 100644 index 0000000..f4f9052 --- /dev/null +++ b/sftp/packet_typing.go @@ -0,0 +1,135 @@ +package sftp + +import ( + "encoding" + "fmt" +) + +// all incoming packets +type requestPacket interface { + encoding.BinaryUnmarshaler + id() uint32 +} + +type responsePacket interface { + encoding.BinaryMarshaler + id() uint32 +} + +// interfaces to group types +type hasPath interface { + requestPacket + getPath() string +} + +type hasHandle interface { + requestPacket + getHandle() string +} + +type notReadOnly interface { + notReadOnly() +} + +//// define types by adding methods +// hasPath +func (p *sshFxpLstatPacket) getPath() string { return p.Path } +func (p *sshFxpStatPacket) getPath() string { return p.Path } +func (p *sshFxpRmdirPacket) getPath() string { return p.Path } +func (p *sshFxpReadlinkPacket) getPath() string { return p.Path } +func (p *sshFxpRealpathPacket) getPath() string { return p.Path } +func (p *sshFxpMkdirPacket) getPath() string { return p.Path } +func (p *sshFxpSetstatPacket) getPath() string { return p.Path } +func (p *sshFxpStatvfsPacket) getPath() string { return p.Path } +func (p *sshFxpRemovePacket) getPath() string { return p.Filename } +func (p *sshFxpRenamePacket) getPath() string { return p.Oldpath } +func (p *sshFxpSymlinkPacket) getPath() string { return p.Targetpath } +func (p *sshFxpOpendirPacket) getPath() string { return p.Path } +func (p *sshFxpOpenPacket) getPath() string { return p.Path } + +func (p *sshFxpExtendedPacketPosixRename) getPath() string { return p.Oldpath } +func (p *sshFxpExtendedPacketHardlink) getPath() string { return p.Oldpath } + +// getHandle +func (p *sshFxpFstatPacket) getHandle() string { return p.Handle } +func (p *sshFxpFsetstatPacket) getHandle() string { return p.Handle } +func (p *sshFxpReadPacket) getHandle() string { return p.Handle } +func (p *sshFxpWritePacket) getHandle() string { return p.Handle } +func (p *sshFxpReaddirPacket) getHandle() string { return p.Handle } +func (p *sshFxpClosePacket) getHandle() string { return p.Handle } + +// notReadOnly +func (p *sshFxpWritePacket) notReadOnly() {} +func (p *sshFxpSetstatPacket) notReadOnly() {} +func (p *sshFxpFsetstatPacket) notReadOnly() {} +func (p *sshFxpRemovePacket) notReadOnly() {} +func (p *sshFxpMkdirPacket) notReadOnly() {} +func (p *sshFxpRmdirPacket) notReadOnly() {} +func (p *sshFxpRenamePacket) notReadOnly() {} +func (p *sshFxpSymlinkPacket) notReadOnly() {} +func (p *sshFxpExtendedPacketPosixRename) notReadOnly() {} +func (p *sshFxpExtendedPacketHardlink) notReadOnly() {} + +// some packets with ID are missing id() +func (p *sshFxpDataPacket) id() uint32 { return p.ID } +func (p *sshFxpStatusPacket) id() uint32 { return p.ID } +func (p *sshFxpStatResponse) id() uint32 { return p.ID } +func (p *sshFxpNamePacket) id() uint32 { return p.ID } +func (p *sshFxpHandlePacket) id() uint32 { return p.ID } +func (p *StatVFS) id() uint32 { return p.ID } +func (p *sshFxVersionPacket) id() uint32 { return 0 } + +// take raw incoming packet data and build packet objects +func makePacket(p rxPacket) (requestPacket, error) { + var pkt requestPacket + switch p.pktType { + case sshFxpInit: + pkt = &sshFxInitPacket{} + case sshFxpLstat: + pkt = &sshFxpLstatPacket{} + case sshFxpOpen: + pkt = &sshFxpOpenPacket{} + case sshFxpClose: + pkt = &sshFxpClosePacket{} + case sshFxpRead: + pkt = &sshFxpReadPacket{} + case sshFxpWrite: + pkt = &sshFxpWritePacket{} + case sshFxpFstat: + pkt = &sshFxpFstatPacket{} + case sshFxpSetstat: + pkt = &sshFxpSetstatPacket{} + case sshFxpFsetstat: + pkt = &sshFxpFsetstatPacket{} + case sshFxpOpendir: + pkt = &sshFxpOpendirPacket{} + case sshFxpReaddir: + pkt = &sshFxpReaddirPacket{} + case sshFxpRemove: + pkt = &sshFxpRemovePacket{} + case sshFxpMkdir: + pkt = &sshFxpMkdirPacket{} + case sshFxpRmdir: + pkt = &sshFxpRmdirPacket{} + case sshFxpRealpath: + pkt = &sshFxpRealpathPacket{} + case sshFxpStat: + pkt = &sshFxpStatPacket{} + case sshFxpRename: + pkt = &sshFxpRenamePacket{} + case sshFxpReadlink: + pkt = &sshFxpReadlinkPacket{} + case sshFxpSymlink: + pkt = &sshFxpSymlinkPacket{} + case sshFxpExtended: + pkt = &sshFxpExtendedPacket{} + default: + return nil, fmt.Errorf("unhandled packet type: %s", p.pktType) + } + if err := pkt.UnmarshalBinary(p.pktBytes); err != nil { + // Return partially unpacked packet to allow callers to return + // error messages appropriately with necessary id() method. + return pkt, err + } + return pkt, nil +} diff --git a/sftp/pool.go b/sftp/pool.go new file mode 100644 index 0000000..3612629 --- /dev/null +++ b/sftp/pool.go @@ -0,0 +1,79 @@ +package sftp + +// bufPool provides a pool of byte-slices to be reused in various parts of the package. +// It is safe to use concurrently through a pointer. +type bufPool struct { + ch chan []byte + blen int +} + +func newBufPool(depth, bufLen int) *bufPool { + return &bufPool{ + ch: make(chan []byte, depth), + blen: bufLen, + } +} + +func (p *bufPool) Get() []byte { + if p.blen <= 0 { + panic("bufPool: new buffer creation length must be greater than zero") + } + + for { + select { + case b := <-p.ch: + if cap(b) < p.blen { + // just in case: throw away any buffer with insufficient capacity. + continue + } + + return b[:p.blen] + + default: + return make([]byte, p.blen) + } + } +} + +func (p *bufPool) Put(b []byte) { + if p == nil { + // functional default: no reuse. + return + } + + if cap(b) < p.blen || cap(b) > p.blen*2 { + // DO NOT reuse buffers with insufficient capacity. + // This could cause panics when resizing to p.blen. + + // DO NOT reuse buffers with excessive capacity. + // This could cause memory leaks. + return + } + + select { + case p.ch <- b: + default: + } +} + +type resChanPool chan chan result + +func newResChanPool(depth int) resChanPool { + return make(chan chan result, depth) +} + +func (p resChanPool) Get() chan result { + select { + case ch := <-p: + return ch + default: + return make(chan result, 1) + } +} + +func (p resChanPool) Put(ch chan result) { + select { + case p <- ch: + default: + } +} diff --git a/sftp/release.go b/sftp/release.go new file mode 100644 index 0000000..b695528 --- /dev/null +++ b/sftp/release.go @@ -0,0 +1,5 @@ +// +build !debug + +package sftp + +func debug(fmt string, args ...interface{}) {} diff --git a/sftp/request-attrs.go b/sftp/request-attrs.go new file mode 100644 index 0000000..b5c95b4 --- /dev/null +++ b/sftp/request-attrs.go @@ -0,0 +1,63 @@ +package sftp + +// Methods on the Request object to make working with the Flags bitmasks and +// Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write +// request and AttrFlags() and Attributes() when working with SetStat requests. +import "os" + +// FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags +// (https://golang.org/pkg/os/#pkg-constants). +type FileOpenFlags struct { + Read, Write, Append, Creat, Trunc, Excl bool +} + +func newFileOpenFlags(flags uint32) FileOpenFlags { + return FileOpenFlags{ + Read: flags&sshFxfRead != 0, + Write: flags&sshFxfWrite != 0, + Append: flags&sshFxfAppend != 0, + Creat: flags&sshFxfCreat != 0, + Trunc: flags&sshFxfTrunc != 0, + Excl: flags&sshFxfExcl != 0, + } +} + +// Pflags converts the bitmap/uint32 from SFTP Open packet pflag values, +// into a FileOpenFlags struct with booleans set for flags set in bitmap. +func (r *Request) Pflags() FileOpenFlags { + return newFileOpenFlags(r.Flags) +} + +// FileAttrFlags that indicate whether SFTP file attributes were passed. When a flag is +// true the corresponding attribute should be available from the FileStat +// object returned by Attributes method. Used with SetStat. +type FileAttrFlags struct { + Size, UidGid, Permissions, Acmodtime bool +} + +func newFileAttrFlags(flags uint32) FileAttrFlags { + return FileAttrFlags{ + Size: (flags & sshFileXferAttrSize) != 0, + UidGid: (flags & sshFileXferAttrUIDGID) != 0, + Permissions: (flags & sshFileXferAttrPermissions) != 0, + Acmodtime: (flags & sshFileXferAttrACmodTime) != 0, + } +} + +// AttrFlags returns a FileAttrFlags boolean struct based on the +// bitmap/uint32 file attribute flags from the SFTP packaet. +func (r *Request) AttrFlags() FileAttrFlags { + return newFileAttrFlags(r.Flags) +} + +// FileMode returns the Mode SFTP file attributes wrapped as os.FileMode +func (a FileStat) FileMode() os.FileMode { + return os.FileMode(a.Mode) +} + +// Attributes parses file attributes byte blob and return them in a +// FileStat object. +func (r *Request) Attributes() *FileStat { + fs, _ := unmarshalFileStat(r.Flags, r.Attrs) + return fs +} diff --git a/sftp/request-errors.go b/sftp/request-errors.go new file mode 100644 index 0000000..6505b5c --- /dev/null +++ b/sftp/request-errors.go @@ -0,0 +1,54 @@ +package sftp + +type fxerr uint32 + +// Error types that match the SFTP's SSH_FXP_STATUS codes. Gives you more +// direct control of the errors being sent vs. letting the library work them +// out from the standard os/io errors. +const ( + ErrSSHFxOk = fxerr(sshFxOk) + ErrSSHFxEOF = fxerr(sshFxEOF) + ErrSSHFxNoSuchFile = fxerr(sshFxNoSuchFile) + ErrSSHFxPermissionDenied = fxerr(sshFxPermissionDenied) + ErrSSHFxFailure = fxerr(sshFxFailure) + ErrSSHFxBadMessage = fxerr(sshFxBadMessage) + ErrSSHFxNoConnection = fxerr(sshFxNoConnection) + ErrSSHFxConnectionLost = fxerr(sshFxConnectionLost) + ErrSSHFxOpUnsupported = fxerr(sshFxOPUnsupported) +) + +// Deprecated error types, these are aliases for the new ones, please use the new ones directly +const ( + ErrSshFxOk = ErrSSHFxOk + ErrSshFxEof = ErrSSHFxEOF + ErrSshFxNoSuchFile = ErrSSHFxNoSuchFile + ErrSshFxPermissionDenied = ErrSSHFxPermissionDenied + ErrSshFxFailure = ErrSSHFxFailure + ErrSshFxBadMessage = ErrSSHFxBadMessage + ErrSshFxNoConnection = ErrSSHFxNoConnection + ErrSshFxConnectionLost = ErrSSHFxConnectionLost + ErrSshFxOpUnsupported = ErrSSHFxOpUnsupported +) + +func (e fxerr) Error() string { + switch e { + case ErrSSHFxOk: + return "OK" + case ErrSSHFxEOF: + return "EOF" + case ErrSSHFxNoSuchFile: + return "no such file" + case ErrSSHFxPermissionDenied: + return "permission denied" + case ErrSSHFxBadMessage: + return "bad message" + case ErrSSHFxNoConnection: + return "no connection" + case ErrSSHFxConnectionLost: + return "connection lost" + case ErrSSHFxOpUnsupported: + return "operation unsupported" + default: + return "failure" + } +} diff --git a/sftp/request-interfaces.go b/sftp/request-interfaces.go new file mode 100644 index 0000000..c8c424c --- /dev/null +++ b/sftp/request-interfaces.go @@ -0,0 +1,121 @@ +package sftp + +import ( + "io" + "os" +) + +// WriterAtReaderAt defines the interface to return when a file is to +// be opened for reading and writing +type WriterAtReaderAt interface { + io.WriterAt + io.ReaderAt +} + +// Interfaces are differentiated based on required returned values. +// All input arguments are to be pulled from Request (the only arg). + +// The Handler interfaces all take the Request object as its only argument. +// All the data you should need to handle the call are in the Request object. +// The request.Method attribute is initially the most important one as it +// determines which Handler gets called. + +// FileReader should return an io.ReaderAt for the filepath +// Note in cases of an error, the error text will be sent to the client. +// Called for Methods: Get +type FileReader interface { + Fileread(*Request) (io.ReaderAt, error) +} + +// FileWriter should return an io.WriterAt for the filepath. +// +// The request server code will call Close() on the returned io.WriterAt +// ojbect if an io.Closer type assertion succeeds. +// Note in cases of an error, the error text will be sent to the client. +// Note when receiving an Append flag it is important to not open files using +// O_APPEND if you plan to use WriteAt, as they conflict. +// Called for Methods: Put, Open +type FileWriter interface { + Filewrite(*Request) (io.WriterAt, error) +} + +// OpenFileWriter is a FileWriter that implements the generic OpenFile method. +// You need to implement this optional interface if you want to be able +// to read and write from/to the same handle. +// Called for Methods: Open +type OpenFileWriter interface { + FileWriter + OpenFile(*Request) (WriterAtReaderAt, error) +} + +// FileCmder should return an error +// Note in cases of an error, the error text will be sent to the client. +// Called for Methods: Setstat, Rename, Rmdir, Mkdir, Link, Symlink, Remove +type FileCmder interface { + Filecmd(*Request) error +} + +// PosixRenameFileCmder is a FileCmder that implements the PosixRename method. +// If this interface is implemented PosixRename requests will call it +// otherwise they will be handled in the same way as Rename +type PosixRenameFileCmder interface { + FileCmder + PosixRename(*Request) error +} + +// StatVFSFileCmder is a FileCmder that implements the StatVFS method. +// You need to implement this interface if you want to handle statvfs requests. +// Please also be sure that the statvfs@openssh.com extension is enabled +type StatVFSFileCmder interface { + FileCmder + StatVFS(*Request) (*StatVFS, error) +} + +// FileLister should return an object that fulfils the ListerAt interface +// Note in cases of an error, the error text will be sent to the client. +// Called for Methods: List, Stat, Readlink +type FileLister interface { + Filelist(*Request) (ListerAt, error) +} + +// LstatFileLister is a FileLister that implements the Lstat method. +// If this interface is implemented Lstat requests will call it +// otherwise they will be handled in the same way as Stat +type LstatFileLister interface { + FileLister + Lstat(*Request) (ListerAt, error) +} + +// RealPathFileLister is a FileLister that implements the Realpath method. +// We use "/" as start directory for relative paths, implementing this +// interface you can customize the start directory. +// You have to return an absolute POSIX path. +type RealPathFileLister interface { + FileLister + RealPath(string) string +} + +// NameLookupFileLister is a FileLister that implmeents the LookupUsername and LookupGroupName methods. +// If this interface is implemented, then longname ls formatting will use these to convert usernames and groupnames. +type NameLookupFileLister interface { + FileLister + LookupUserName(string) string + LookupGroupName(string) string +} + +// ListerAt does for file lists what io.ReaderAt does for files. +// ListAt should return the number of entries copied and an io.EOF +// error if at end of list. This is testable by comparing how many you +// copied to how many could be copied (eg. n < len(ls) below). +// The copy() builtin is best for the copying. +// Note in cases of an error, the error text will be sent to the client. +type ListerAt interface { + ListAt([]os.FileInfo, int64) (int, error) +} + +// TransferError is an optional interface that readerAt and writerAt +// can implement to be notified about the error causing Serve() to exit +// with the request still open +type TransferError interface { + TransferError(err error) +} diff --git a/sftp/request-plan9.go b/sftp/request-plan9.go new file mode 100644 index 0000000..2444da5 --- /dev/null +++ b/sftp/request-plan9.go @@ -0,0 +1,34 @@ +// +build plan9 + +package sftp + +import ( + "path" + "path/filepath" + "syscall" +) + +func fakeFileInfoSys() interface{} { + return &syscall.Dir{} +} + +func testOsSys(sys interface{}) error { + return nil +} + +func toLocalPath(p string) string { + lp := filepath.FromSlash(p) + + if path.IsAbs(p) { + tmp := lp[1:] + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes is absolute, + // then we have a filepath encoded with a prefix '/'. + // e.g. "/#s/boot" to "#s/boot" + return tmp + } + } + + return lp +} diff --git a/sftp/request-server.go b/sftp/request-server.go new file mode 100644 index 0000000..5fa828b --- /dev/null +++ b/sftp/request-server.go @@ -0,0 +1,304 @@ +package sftp + +import ( + "context" + "errors" + "io" + "path" + "path/filepath" + "strconv" + "sync" +) + +var maxTxPacket uint32 = 1 << 15 + +// Handlers contains the 4 SFTP server request handlers. +type Handlers struct { + FileGet FileReader + FilePut FileWriter + FileCmd FileCmder + FileList FileLister +} + +// RequestServer abstracts the sftp protocol with an http request-like protocol +type RequestServer struct { + Handlers Handlers + + *serverConn + pktMgr *packetManager + + mu sync.RWMutex + handleCount int + openRequests map[string]*Request +} + +// A RequestServerOption is a function which applies configuration to a RequestServer. +type RequestServerOption func(*RequestServer) + +// WithRSAllocator enable the allocator. +// After processing a packet we keep in memory the allocated slices +// and we reuse them for new packets. +// The allocator is experimental +func WithRSAllocator() RequestServerOption { + return func(rs *RequestServer) { + alloc := newAllocator() + rs.pktMgr.alloc = alloc + rs.conn.alloc = alloc + } +} + +// NewRequestServer creates/allocates/returns new RequestServer. +// Normally there will be one server per user-session. +func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer { + svrConn := &serverConn{ + conn: conn{ + Reader: rwc, + WriteCloser: rwc, + }, + } + rs := &RequestServer{ + Handlers: h, + + serverConn: svrConn, + pktMgr: newPktMgr(svrConn), + + openRequests: make(map[string]*Request), + } + + for _, o := range options { + o(rs) + } + return rs +} + +// New Open packet/Request +func (rs *RequestServer) nextRequest(r *Request) string { + rs.mu.Lock() + defer rs.mu.Unlock() + + rs.handleCount++ + + r.handle = strconv.Itoa(rs.handleCount) + rs.openRequests[r.handle] = r + + return r.handle +} + +// Returns Request from openRequests, bool is false if it is missing. +// +// The Requests in openRequests work essentially as open file descriptors that +// you can do different things with. What you are doing with it are denoted by +// the first packet of that type (read/write/etc). +func (rs *RequestServer) getRequest(handle string) (*Request, bool) { + rs.mu.RLock() + defer rs.mu.RUnlock() + + r, ok := rs.openRequests[handle] + return r, ok +} + +// Close the Request and clear from openRequests map +func (rs *RequestServer) closeRequest(handle string) error { + rs.mu.Lock() + defer rs.mu.Unlock() + + if r, ok := rs.openRequests[handle]; ok { + delete(rs.openRequests, handle) + return r.close() + } + + return EBADF +} + +// Close the read/write/closer to trigger exiting the main server loop +func (rs *RequestServer) Close() error { return rs.conn.Close() } + +func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error { + defer close(pktChan) // shuts down sftpServerWorkers + + var err error + var pkt requestPacket + var pktType uint8 + var pktBytes []byte + + for { + pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID()) + if err != nil { + // we don't care about releasing allocated pages here, the server will quit and the allocator freed + return err + } + + pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) + if err != nil { + switch { + case errors.Is(err, errUnknownExtendedPacket): + // do nothing + default: + debug("makePacket err: %v", err) + rs.conn.Close() // shuts down recvPacket + return err + } + } + + pktChan <- rs.pktMgr.newOrderedRequest(pkt) + } +} + +// Serve requests for user session +func (rs *RequestServer) Serve() error { + defer func() { + if rs.pktMgr.alloc != nil { + rs.pktMgr.alloc.Free() + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + runWorker := func(ch chan orderedRequest) { + wg.Add(1) + go func() { + defer wg.Done() + if err := rs.packetWorker(ctx, ch); err != nil { + rs.conn.Close() // shuts down recvPacket + } + }() + } + pktChan := rs.pktMgr.workerChan(runWorker) + + err := rs.serveLoop(pktChan) + + wg.Wait() // wait for all workers to exit + + rs.mu.Lock() + defer rs.mu.Unlock() + + // make sure all open requests are properly closed + // (eg. possible on dropped connections, client crashes, etc.) + for handle, req := range rs.openRequests { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + req.transferError(err) + + delete(rs.openRequests, handle) + req.close() + } + + return err +} + +func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedRequest) error { + for pkt := range pktChan { + orderID := pkt.orderID() + if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok { + if epkt.SpecificPacket != nil { + pkt.requestPacket = epkt.SpecificPacket + } + } + + var rpkt responsePacket + switch pkt := pkt.requestPacket.(type) { + case *sshFxInitPacket: + rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions} + case *sshFxpClosePacket: + handle := pkt.getHandle() + rpkt = statusFromError(pkt.ID, rs.closeRequest(handle)) + case *sshFxpRealpathPacket: + var realPath string + if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok { + realPath = realPather.RealPath(pkt.getPath()) + } else { + realPath = cleanPath(pkt.getPath()) + } + rpkt = cleanPacketPath(pkt, realPath) + case *sshFxpOpendirPacket: + request := requestFromPacket(ctx, pkt) + handle := rs.nextRequest(request) + rpkt = request.opendir(rs.Handlers, pkt) + if _, ok := rpkt.(*sshFxpHandlePacket); !ok { + // if we return an error we have to remove the handle from the active ones + rs.closeRequest(handle) + } + case *sshFxpOpenPacket: + request := requestFromPacket(ctx, pkt) + handle := rs.nextRequest(request) + rpkt = request.open(rs.Handlers, pkt) + if _, ok := rpkt.(*sshFxpHandlePacket); !ok { + // if we return an error we have to remove the handle from the active ones + rs.closeRequest(handle) + } + case *sshFxpFstatPacket: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + if !ok { + rpkt = statusFromError(pkt.ID, EBADF) + } else { + request = NewRequest("Stat", request.Filepath) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + } + case *sshFxpFsetstatPacket: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + if !ok { + rpkt = statusFromError(pkt.ID, EBADF) + } else { + request = NewRequest("Setstat", request.Filepath) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + } + case *sshFxpExtendedPacketPosixRename: + request := NewRequest("PosixRename", pkt.Oldpath) + request.Target = pkt.Newpath + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + case *sshFxpExtendedPacketStatVFS: + request := NewRequest("StatVFS", pkt.Path) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + case hasHandle: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + if !ok { + rpkt = statusFromError(pkt.id(), EBADF) + } else { + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + } + case hasPath: + request := requestFromPacket(ctx, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) + request.close() + default: + rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported) + } + + rs.pktMgr.readyPacket( + rs.pktMgr.newOrderedResponse(rpkt, orderID)) + } + return nil +} + +// clean and return name packet for file +func cleanPacketPath(pkt *sshFxpRealpathPacket, realPath string) responsePacket { + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: []*sshFxpNameAttr{ + { + Name: realPath, + LongName: realPath, + Attrs: emptyFileStat, + }, + }, + } +} + +// Makes sure we have a clean POSIX (/) absolute path to work with +func cleanPath(p string) string { + return cleanPathWithBase("/", p) +} + +func cleanPathWithBase(base, p string) string { + p = filepath.ToSlash(filepath.Clean(p)) + if !path.IsAbs(p) { + return path.Join(base, p) + } + return p +} diff --git a/sftp/request-unix.go b/sftp/request-unix.go new file mode 100644 index 0000000..50b08a3 --- /dev/null +++ b/sftp/request-unix.go @@ -0,0 +1,27 @@ +// +build !windows,!plan9 + +package sftp + +import ( + "errors" + "syscall" +) + +func fakeFileInfoSys() interface{} { + return &syscall.Stat_t{Uid: 65534, Gid: 65534} +} + +func testOsSys(sys interface{}) error { + fstat := sys.(*FileStat) + if fstat.UID != uint32(65534) { + return errors.New("Uid failed to match") + } + if fstat.GID != uint32(65534) { + return errors.New("Gid failed to match") + } + return nil +} + +func toLocalPath(p string) string { + return p +} diff --git a/sftp/request.go b/sftp/request.go new file mode 100644 index 0000000..c6da4b6 --- /dev/null +++ b/sftp/request.go @@ -0,0 +1,628 @@ +package sftp + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "strings" + "sync" + "syscall" +) + +// MaxFilelist is the max number of files to return in a readdir batch. +var MaxFilelist int64 = 100 + +// state encapsulates the reader/writer/readdir from handlers. +type state struct { + mu sync.RWMutex + + writerAt io.WriterAt + readerAt io.ReaderAt + writerAtReaderAt WriterAtReaderAt + listerAt ListerAt + lsoffset int64 +} + +// copy returns a shallow copy the state. +// This is broken out to specific fields, +// because we have to copy around the mutex in state. +func (s *state) copy() state { + s.mu.RLock() + defer s.mu.RUnlock() + + return state{ + writerAt: s.writerAt, + readerAt: s.readerAt, + writerAtReaderAt: s.writerAtReaderAt, + listerAt: s.listerAt, + lsoffset: s.lsoffset, + } +} + +func (s *state) setReaderAt(rd io.ReaderAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.readerAt = rd +} + +func (s *state) getReaderAt() io.ReaderAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.readerAt +} + +func (s *state) setWriterAt(rd io.WriterAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.writerAt = rd +} + +func (s *state) getWriterAt() io.WriterAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.writerAt +} + +func (s *state) setWriterAtReaderAt(rw WriterAtReaderAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.writerAtReaderAt = rw +} + +func (s *state) getWriterAtReaderAt() WriterAtReaderAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.writerAtReaderAt +} + +func (s *state) getAllReaderWriters() (io.ReaderAt, io.WriterAt, WriterAtReaderAt) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.readerAt, s.writerAt, s.writerAtReaderAt +} + +// Returns current offset for file list +func (s *state) lsNext() int64 { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.lsoffset +} + +// Increases next offset +func (s *state) lsInc(offset int64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.lsoffset += offset +} + +// manage file read/write state +func (s *state) setListerAt(la ListerAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.listerAt = la +} + +func (s *state) getListerAt() ListerAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.listerAt +} + +// Request contains the data and state for the incoming service request. +type Request struct { + // Get, Put, Setstat, Stat, Rename, Remove + // Rmdir, Mkdir, List, Readlink, Link, Symlink + Method string + Filepath string + Flags uint32 + Attrs []byte // convert to sub-struct + Target string // for renames and sym-links + handle string + + // reader/writer/readdir from handlers + state + + // context lasts duration of request + ctx context.Context + cancelCtx context.CancelFunc +} + +// NewRequest creates a new Request object. +func NewRequest(method, path string) *Request { + return &Request{ + Method: method, + Filepath: cleanPath(path), + } +} + +// copy returns a shallow copy of existing request. +// This is broken out to specific fields, +// because we have to copy around the mutex in state. +func (r *Request) copy() *Request { + return &Request{ + Method: r.Method, + Filepath: r.Filepath, + Flags: r.Flags, + Attrs: r.Attrs, + Target: r.Target, + handle: r.handle, + + state: r.state.copy(), + + ctx: r.ctx, + cancelCtx: r.cancelCtx, + } +} + +// New Request initialized based on packet data +func requestFromPacket(ctx context.Context, pkt hasPath) *Request { + method := requestMethod(pkt) + request := NewRequest(method, pkt.getPath()) + request.ctx, request.cancelCtx = context.WithCancel(ctx) + + switch p := pkt.(type) { + case *sshFxpOpenPacket: + request.Flags = p.Pflags + case *sshFxpSetstatPacket: + request.Flags = p.Flags + request.Attrs = p.Attrs.([]byte) + case *sshFxpRenamePacket: + request.Target = cleanPath(p.Newpath) + case *sshFxpSymlinkPacket: + // NOTE: given a POSIX compliant signature: symlink(target, linkpath string) + // this makes Request.Target the linkpath, and Request.Filepath the target. + request.Target = cleanPath(p.Linkpath) + case *sshFxpExtendedPacketHardlink: + request.Target = cleanPath(p.Newpath) + } + return request +} + +// Context returns the request's context. To change the context, +// use WithContext. +// +// The returned context is always non-nil; it defaults to the +// background context. +// +// For incoming server requests, the context is canceled when the +// request is complete or the client's connection closes. +func (r *Request) Context() context.Context { + if r.ctx != nil { + return r.ctx + } + return context.Background() +} + +// WithContext returns a copy of r with its context changed to ctx. +// The provided ctx must be non-nil. +func (r *Request) WithContext(ctx context.Context) *Request { + if ctx == nil { + panic("nil context") + } + r2 := r.copy() + r2.ctx = ctx + r2.cancelCtx = nil + return r2 +} + +// Close reader/writer if possible +func (r *Request) close() error { + defer func() { + if r.cancelCtx != nil { + r.cancelCtx() + } + }() + + rd, wr, rw := r.getAllReaderWriters() + + var err error + + // Close errors on a Writer are far more likely to be the important one. + // As they can be information that there was a loss of data. + if c, ok := wr.(io.Closer); ok { + if err2 := c.Close(); err == nil { + // update error if it is still nil + err = err2 + } + } + + if c, ok := rw.(io.Closer); ok { + if err2 := c.Close(); err == nil { + // update error if it is still nil + err = err2 + + r.setWriterAtReaderAt(nil) + } + } + + if c, ok := rd.(io.Closer); ok { + if err2 := c.Close(); err == nil { + // update error if it is still nil + err = err2 + } + } + + return err +} + +// Notify transfer error if any +func (r *Request) transferError(err error) { + if err == nil { + return + } + + rd, wr, rw := r.getAllReaderWriters() + + if t, ok := wr.(TransferError); ok { + t.TransferError(err) + } + + if t, ok := rw.(TransferError); ok { + t.TransferError(err) + } + + if t, ok := rd.(TransferError); ok { + t.TransferError(err) + } +} + +// called from worker to handle packet/request +func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { + switch r.Method { + case "Get": + return fileget(handlers.FileGet, r, pkt, alloc, orderID) + case "Put": + return fileput(handlers.FilePut, r, pkt, alloc, orderID) + case "Open": + return fileputget(handlers.FilePut, r, pkt, alloc, orderID) + case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS": + return filecmd(handlers.FileCmd, r, pkt) + case "List": + return filelist(handlers.FileList, r, pkt) + case "Stat", "Lstat", "Readlink": + return filestat(handlers.FileList, r, pkt) + default: + return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method)) + } +} + +// Additional initialization for Open packets +func (r *Request) open(h Handlers, pkt requestPacket) responsePacket { + flags := r.Pflags() + + id := pkt.id() + + switch { + case flags.Write, flags.Append, flags.Creat, flags.Trunc: + if flags.Read { + if openFileWriter, ok := h.FilePut.(OpenFileWriter); ok { + r.Method = "Open" + rw, err := openFileWriter.OpenFile(r) + if err != nil { + return statusFromError(id, err) + } + + r.setWriterAtReaderAt(rw) + + return &sshFxpHandlePacket{ + ID: id, + Handle: r.handle, + } + } + } + + r.Method = "Put" + wr, err := h.FilePut.Filewrite(r) + if err != nil { + return statusFromError(id, err) + } + + r.setWriterAt(wr) + + case flags.Read: + r.Method = "Get" + rd, err := h.FileGet.Fileread(r) + if err != nil { + return statusFromError(id, err) + } + + r.setReaderAt(rd) + + default: + return statusFromError(id, errors.New("bad file flags")) + } + + return &sshFxpHandlePacket{ + ID: id, + Handle: r.handle, + } +} + +func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { + r.Method = "List" + la, err := h.FileList.Filelist(r) + if err != nil { + return statusFromError(pkt.id(), wrapPathError(r.Filepath, err)) + } + + r.setListerAt(la) + + return &sshFxpHandlePacket{ + ID: pkt.id(), + Handle: r.handle, + } +} + +// wrap FileReader handler +func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { + rd := r.getReaderAt() + if rd == nil { + return statusFromError(pkt.id(), errors.New("unexpected read packet")) + } + + data, offset, _ := packetData(pkt, alloc, orderID) + + n, err := rd.ReadAt(data, offset) + // only return EOF error if no data left to read + if err != nil && (err != io.EOF || n == 0) { + return statusFromError(pkt.id(), err) + } + + return &sshFxpDataPacket{ + ID: pkt.id(), + Length: uint32(n), + Data: data[:n], + } +} + +// wrap FileWriter handler +func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { + wr := r.getWriterAt() + if wr == nil { + return statusFromError(pkt.id(), errors.New("unexpected write packet")) + } + + data, offset, _ := packetData(pkt, alloc, orderID) + + _, err := wr.WriteAt(data, offset) + return statusFromError(pkt.id(), err) +} + +// wrap OpenFileWriter handler +func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { + rw := r.getWriterAtReaderAt() + if rw == nil { + return statusFromError(pkt.id(), errors.New("unexpected write and read packet")) + } + + switch p := pkt.(type) { + case *sshFxpReadPacket: + data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset) + + n, err := rw.ReadAt(data, offset) + // only return EOF error if no data left to read + if err != nil && (err != io.EOF || n == 0) { + return statusFromError(pkt.id(), err) + } + + return &sshFxpDataPacket{ + ID: pkt.id(), + Length: uint32(n), + Data: data[:n], + } + + case *sshFxpWritePacket: + data, offset := p.Data, int64(p.Offset) + + _, err := rw.WriteAt(data, offset) + return statusFromError(pkt.id(), err) + + default: + return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write")) + } +} + +// file data for additional read/write packets +func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) { + switch p := p.(type) { + case *sshFxpReadPacket: + return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len + case *sshFxpWritePacket: + return p.Data, int64(p.Offset), p.Length + } + return +} + +// wrap FileCmder handler +func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket { + switch p := pkt.(type) { + case *sshFxpFsetstatPacket: + r.Flags = p.Flags + r.Attrs = p.Attrs.([]byte) + } + + switch r.Method { + case "PosixRename": + if posixRenamer, ok := h.(PosixRenameFileCmder); ok { + err := posixRenamer.PosixRename(r) + return statusFromError(pkt.id(), err) + } + + // PosixRenameFileCmder not implemented handle this request as a Rename + r.Method = "Rename" + err := h.Filecmd(r) + return statusFromError(pkt.id(), err) + + case "StatVFS": + if statVFSCmdr, ok := h.(StatVFSFileCmder); ok { + stat, err := statVFSCmdr.StatVFS(r) + if err != nil { + return statusFromError(pkt.id(), err) + } + stat.ID = pkt.id() + return stat + } + + return statusFromError(pkt.id(), ErrSSHFxOpUnsupported) + } + + err := h.Filecmd(r) + return statusFromError(pkt.id(), err) +} + +// wrap FileLister handler +func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket { + lister := r.getListerAt() + if lister == nil { + return statusFromError(pkt.id(), errors.New("unexpected dir packet")) + } + + offset := r.lsNext() + finfo := make([]os.FileInfo, MaxFilelist) + n, err := lister.ListAt(finfo, offset) + r.lsInc(int64(n)) + // ignore EOF as we only return it when there are no results + finfo = finfo[:n] // avoid need for nil tests below + + switch r.Method { + case "List": + if err != nil && (err != io.EOF || n == 0) { + return statusFromError(pkt.id(), err) + } + + nameAttrs := make([]*sshFxpNameAttr, 0, len(finfo)) + + // If the type conversion fails, we get untyped `nil`, + // which is handled by not looking up any names. + idLookup, _ := h.(NameLookupFileLister) + + for _, fi := range finfo { + nameAttrs = append(nameAttrs, &sshFxpNameAttr{ + Name: fi.Name(), + LongName: runLs(idLookup, fi), + Attrs: []interface{}{fi}, + }) + } + + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: nameAttrs, + } + + default: + err = fmt.Errorf("unexpected method: %s", r.Method) + return statusFromError(pkt.id(), err) + } +} + +func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket { + var lister ListerAt + var err error + + if r.Method == "Lstat" { + if lstatFileLister, ok := h.(LstatFileLister); ok { + lister, err = lstatFileLister.Lstat(r) + } else { + // LstatFileLister not implemented handle this request as a Stat + r.Method = "Stat" + lister, err = h.Filelist(r) + } + } else { + lister, err = h.Filelist(r) + } + if err != nil { + return statusFromError(pkt.id(), err) + } + finfo := make([]os.FileInfo, 1) + n, err := lister.ListAt(finfo, 0) + finfo = finfo[:n] // avoid need for nil tests below + + switch r.Method { + case "Stat", "Lstat": + if err != nil && err != io.EOF { + return statusFromError(pkt.id(), err) + } + if n == 0 { + err = &os.PathError{ + Op: strings.ToLower(r.Method), + Path: r.Filepath, + Err: syscall.ENOENT, + } + return statusFromError(pkt.id(), err) + } + return &sshFxpStatResponse{ + ID: pkt.id(), + info: finfo[0], + } + case "Readlink": + if err != nil && err != io.EOF { + return statusFromError(pkt.id(), err) + } + if n == 0 { + err = &os.PathError{ + Op: "readlink", + Path: r.Filepath, + Err: syscall.ENOENT, + } + return statusFromError(pkt.id(), err) + } + filename := finfo[0].Name() + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: []*sshFxpNameAttr{ + { + Name: filename, + LongName: filename, + Attrs: emptyFileStat, + }, + }, + } + default: + err = fmt.Errorf("unexpected method: %s", r.Method) + return statusFromError(pkt.id(), err) + } +} + +// init attributes of request object from packet data +func requestMethod(p requestPacket) (method string) { + switch p.(type) { + case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket: + // set in open() above + case *sshFxpOpendirPacket, *sshFxpReaddirPacket: + // set in opendir() above + case *sshFxpSetstatPacket, *sshFxpFsetstatPacket: + method = "Setstat" + case *sshFxpRenamePacket: + method = "Rename" + case *sshFxpSymlinkPacket: + method = "Symlink" + case *sshFxpRemovePacket: + method = "Remove" + case *sshFxpStatPacket, *sshFxpFstatPacket: + method = "Stat" + case *sshFxpLstatPacket: + method = "Lstat" + case *sshFxpRmdirPacket: + method = "Rmdir" + case *sshFxpReadlinkPacket: + method = "Readlink" + case *sshFxpMkdirPacket: + method = "Mkdir" + case *sshFxpExtendedPacketHardlink: + method = "Link" + } + return method +} diff --git a/sftp/request_windows.go b/sftp/request_windows.go new file mode 100644 index 0000000..1f6d3df --- /dev/null +++ b/sftp/request_windows.go @@ -0,0 +1,44 @@ +package sftp + +import ( + "path" + "path/filepath" + "syscall" +) + +func fakeFileInfoSys() interface{} { + return syscall.Win32FileAttributeData{} +} + +func testOsSys(sys interface{}) error { + return nil +} + +func toLocalPath(p string) string { + lp := filepath.FromSlash(p) + + if path.IsAbs(p) { + tmp := lp + for len(tmp) > 0 && tmp[0] == '\\' { + tmp = tmp[1:] + } + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes is absolute, + // then we have a filepath encoded with a prefix '/'. + // e.g. "/C:/Windows" to "C:\\Windows" + return tmp + } + + tmp += "\\" + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes but with extra end slash is absolute, + // then we have a filepath encoded with a prefix '/' and a dropped '/' at the end. + // e.g. "/C:" to "C:\\" + return tmp + } + } + + return lp +} diff --git a/sftp/server.go b/sftp/server.go new file mode 100644 index 0000000..6b2b20f --- /dev/null +++ b/sftp/server.go @@ -0,0 +1,643 @@ +package sftp + +// sftp server counterpart + +import ( + "context" + "log" + "encoding" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "strconv" + "sync" + "syscall" + "time" + + "git.deuxfleurs.fr/Deuxfleurs/bagage/s3" +) + +const ( + // SftpServerWorkerCount defines the number of workers for the SFTP server + SftpServerWorkerCount = 8 +) + +// Server is an SSH File Transfer Protocol (sftp) server. +// This is intended to provide the sftp subsystem to an ssh server daemon. +// This implementation currently supports most of sftp server protocol version 3, +// as specified at http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 +type Server struct { + *serverConn + debugStream io.Writer + readOnly bool + pktMgr *packetManager + openFiles map[string]*s3.S3File + openFilesLock sync.RWMutex + handleCount int + fs *s3.S3FS + ctx context.Context +} + +func (svr *Server) nextHandle(f *s3.S3File) string { + svr.openFilesLock.Lock() + defer svr.openFilesLock.Unlock() + svr.handleCount++ + handle := strconv.Itoa(svr.handleCount) + svr.openFiles[handle] = f + return handle +} + +func (svr *Server) closeHandle(handle string) error { + svr.openFilesLock.Lock() + defer svr.openFilesLock.Unlock() + if f, ok := svr.openFiles[handle]; ok { + delete(svr.openFiles, handle) + return f.Close() + } + + return EBADF +} + +func (svr *Server) getHandle(handle string) (*s3.S3File, bool) { + svr.openFilesLock.RLock() + defer svr.openFilesLock.RUnlock() + f, ok := svr.openFiles[handle] + return f, ok +} + +type serverRespondablePacket interface { + encoding.BinaryUnmarshaler + id() uint32 + respond(svr *Server) responsePacket +} + +// NewServer creates a new Server instance around the provided streams, serving +// content from the root of the filesystem. Optionally, ServerOption +// functions may be specified to further configure the Server. +// +// A subsequent call to Serve() is required to begin serving files over SFTP. +func NewServer(ctx context.Context, rwc io.ReadWriteCloser, fs *s3.S3FS, options ...ServerOption) (*Server, error) { + svrConn := &serverConn{ + conn: conn{ + Reader: rwc, + WriteCloser: rwc, + }, + } + s := &Server{ + serverConn: svrConn, + debugStream: ioutil.Discard, + pktMgr: newPktMgr(svrConn), + openFiles: make(map[string]*s3.S3File), + fs: fs, + ctx: ctx, + } + + for _, o := range options { + if err := o(s); err != nil { + return nil, err + } + } + + return s, nil +} + +// A ServerOption is a function which applies configuration to a Server. +type ServerOption func(*Server) error + +// WithDebug enables Server debugging output to the supplied io.Writer. +func WithDebug(w io.Writer) ServerOption { + return func(s *Server) error { + s.debugStream = w + return nil + } +} + +// ReadOnly configures a Server to serve files in read-only mode. +func ReadOnly() ServerOption { + return func(s *Server) error { + s.readOnly = true + return nil + } +} + +// WithAllocator enable the allocator. +// After processing a packet we keep in memory the allocated slices +// and we reuse them for new packets. +// The allocator is experimental +func WithAllocator() ServerOption { + return func(s *Server) error { + alloc := newAllocator() + s.pktMgr.alloc = alloc + s.conn.alloc = alloc + return nil + } +} + +type rxPacket struct { + pktType fxp + pktBytes []byte +} + +// Up to N parallel servers +func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error { + for pkt := range pktChan { + // readonly checks + readonly := true + switch pkt := pkt.requestPacket.(type) { + case notReadOnly: + readonly = false + case *sshFxpOpenPacket: + readonly = pkt.readonly() + case *sshFxpExtendedPacket: + readonly = pkt.readonly() + } + + // If server is operating read-only and a write operation is requested, + // return permission denied + if !readonly && svr.readOnly { + svr.pktMgr.readyPacket( + svr.pktMgr.newOrderedResponse(statusFromError(pkt.id(), syscall.EPERM), pkt.orderID()), + ) + continue + } + + if err := handlePacket(svr, pkt); err != nil { + return err + } + } + return nil +} + +func handlePacket(s *Server, p orderedRequest) error { + var rpkt responsePacket + orderID := p.orderID() + switch p := p.requestPacket.(type) { + case *sshFxInitPacket: + log.Println("pkt: init") + rpkt = &sshFxVersionPacket{ + Version: sftpProtocolVersion, + Extensions: sftpExtensions, + } + case *sshFxpStatPacket: + log.Println("pkt: stat: ", p.Path) + // stat the requested file + info, err := os.Stat(toLocalPath(p.Path)) + rpkt = &sshFxpStatResponse{ + ID: p.ID, + info: info, + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + case *sshFxpLstatPacket: + log.Println("pkt: lstat: ", p.Path) + // stat the requested file + info, err := os.Lstat(toLocalPath(p.Path)) + rpkt = &sshFxpStatResponse{ + ID: p.ID, + info: info, + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + case *sshFxpFstatPacket: + log.Println("pkt: fstat: ", p.Handle) + f, ok := s.getHandle(p.Handle) + var err error = EBADF + var info os.FileInfo + if ok { + info, err = f.Stat() + rpkt = &sshFxpStatResponse{ + ID: p.ID, + info: info, + } + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + case *sshFxpMkdirPacket: + log.Println("pkt: mkdir: ", p.Path) + err := os.Mkdir(toLocalPath(p.Path), 0755) + rpkt = statusFromError(p.ID, err) + case *sshFxpRmdirPacket: + log.Println("pkt: rmdir: ", p.Path) + err := os.Remove(toLocalPath(p.Path)) + rpkt = statusFromError(p.ID, err) + case *sshFxpRemovePacket: + log.Println("pkt: rm: ", p.Filename) + err := os.Remove(toLocalPath(p.Filename)) + rpkt = statusFromError(p.ID, err) + case *sshFxpRenamePacket: + log.Println("pkt: rename: ", p.Oldpath, ", ", p.Newpath) + err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath)) + rpkt = statusFromError(p.ID, err) + case *sshFxpSymlinkPacket: + log.Println("pkt: ln -s: ", p.Targetpath, ", ", p.Linkpath) + err := os.Symlink(toLocalPath(p.Targetpath), toLocalPath(p.Linkpath)) + rpkt = statusFromError(p.ID, err) + case *sshFxpClosePacket: + log.Println("pkt: close handle: ", p.Handle) + rpkt = statusFromError(p.ID, s.closeHandle(p.Handle)) + case *sshFxpReadlinkPacket: + log.Println("pkt: read: ", p.Path) + f, err := os.Readlink(toLocalPath(p.Path)) + rpkt = &sshFxpNamePacket{ + ID: p.ID, + NameAttrs: []*sshFxpNameAttr{ + { + Name: f, + LongName: f, + Attrs: emptyFileStat, + }, + }, + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + case *sshFxpRealpathPacket: + log.Println("pkt: absolute path: ", p.Path) + f := s3.NewS3Path(p.Path).Path + rpkt = &sshFxpNamePacket{ + ID: p.ID, + NameAttrs: []*sshFxpNameAttr{ + { + Name: f, + LongName: f, + Attrs: emptyFileStat, + }, + }, + } + case *sshFxpOpendirPacket: + log.Println("pkt: open dir: ", p.Path) + p.Path = s3.NewS3Path(p.Path).Path + + if stat, err := s.fs.Stat(s.ctx, p.Path); err != nil { + rpkt = statusFromError(p.ID, err) + } else if !stat.IsDir() { + rpkt = statusFromError(p.ID, &os.PathError{ + Path: p.Path, Err: syscall.ENOTDIR}) + } else { + rpkt = (&sshFxpOpenPacket{ + ID: p.ID, + Path: p.Path, + Pflags: sshFxfRead, + }).respond(s) + } + case *sshFxpReadPacket: + log.Println("pkt: read handle: ", p.Handle) + var err error = EBADF + f, ok := s.getHandle(p.Handle) + if ok { + err = nil + data := p.getDataSlice(s.pktMgr.alloc, orderID) + n, _err := f.ReadAt(data, int64(p.Offset)) + if _err != nil && (_err != io.EOF || n == 0) { + err = _err + } + rpkt = &sshFxpDataPacket{ + ID: p.ID, + Length: uint32(n), + Data: data[:n], + // do not use data[:n:n] here to clamp the capacity, we allocated extra capacity above to avoid reallocations + } + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + + case *sshFxpWritePacket: + log.Println("pkt: write handle: ", p.Handle, ", Offset: ", p.Offset) + f, ok := s.getHandle(p.Handle) + var err error = EBADF + if ok { + _, err = f.WriteAt(p.Data, int64(p.Offset)) + } + rpkt = statusFromError(p.ID, err) + case *sshFxpExtendedPacket: + log.Println("pkt: extended packet") + if p.SpecificPacket == nil { + rpkt = statusFromError(p.ID, ErrSSHFxOpUnsupported) + } else { + rpkt = p.respond(s) + } + case serverRespondablePacket: + log.Println("pkt: respondable") + rpkt = p.respond(s) + default: + return fmt.Errorf("unexpected packet type %T", p) + } + + s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, orderID)) + return nil +} + +// Serve serves SFTP connections until the streams stop or the SFTP subsystem +// is stopped. +func (svr *Server) Serve() error { + defer func() { + if svr.pktMgr.alloc != nil { + svr.pktMgr.alloc.Free() + } + }() + var wg sync.WaitGroup + runWorker := func(ch chan orderedRequest) { + wg.Add(1) + go func() { + defer wg.Done() + if err := svr.sftpServerWorker(ch); err != nil { + svr.conn.Close() // shuts down recvPacket + } + }() + } + pktChan := svr.pktMgr.workerChan(runWorker) + + var err error + var pkt requestPacket + var pktType uint8 + var pktBytes []byte + for { + pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID()) + if err != nil { + // we don't care about releasing allocated pages here, the server will quit and the allocator freed + break + } + + pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) + if err != nil { + switch { + case errors.Is(err, errUnknownExtendedPacket): + //if err := svr.serverConn.sendError(pkt, ErrSshFxOpUnsupported); err != nil { + // debug("failed to send err packet: %v", err) + // svr.conn.Close() // shuts down recvPacket + // break + //} + default: + debug("makePacket err: %v", err) + svr.conn.Close() // shuts down recvPacket + break + } + } + + pktChan <- svr.pktMgr.newOrderedRequest(pkt) + } + + close(pktChan) // shuts down sftpServerWorkers + wg.Wait() // wait for all workers to exit + + // close any still-open files + for handle, file := range svr.openFiles { + fmt.Fprintf(svr.debugStream, "sftp server file with handle %q left open: %v\n", handle, file.Path.Path) + file.Close() + } + return err // error from recvPacket +} + +type ider interface { + id() uint32 +} + +// The init packet has no ID, so we just return a zero-value ID +func (p *sshFxInitPacket) id() uint32 { return 0 } + +type sshFxpStatResponse struct { + ID uint32 + info os.FileInfo +} + +func (p *sshFxpStatResponse) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(id) + + b := make([]byte, 4, l) + b = append(b, sshFxpAttrs) + b = marshalUint32(b, p.ID) + + var payload []byte + payload = marshalFileInfo(payload, p.info) + + return b, payload, nil +} + +func (p *sshFxpStatResponse) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +var emptyFileStat = []interface{}{uint32(0)} + +func (p *sshFxpOpenPacket) readonly() bool { + return !p.hasPflags(sshFxfWrite) +} + +func (p *sshFxpOpenPacket) hasPflags(flags ...uint32) bool { + for _, f := range flags { + if p.Pflags&f == 0 { + return false + } + } + return true +} + +func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { + log.Println("pkt: open: ", p.Path) + var osFlags int + if p.hasPflags(sshFxfRead, sshFxfWrite) { + osFlags |= os.O_RDWR + } else if p.hasPflags(sshFxfWrite) { + osFlags |= os.O_WRONLY + } else if p.hasPflags(sshFxfRead) { + osFlags |= os.O_RDONLY + } else { + // how are they opening? + return statusFromError(p.ID, syscall.EINVAL) + } + + // Don't use O_APPEND flag as it conflicts with WriteAt. + // The sshFxfAppend flag is a no-op here as the client sends the offsets. + // @FIXME these flags are currently ignored + if p.hasPflags(sshFxfCreat) { + osFlags |= os.O_CREATE + } + if p.hasPflags(sshFxfTrunc) { + osFlags |= os.O_TRUNC + } + if p.hasPflags(sshFxfExcl) { + osFlags |= os.O_EXCL + } + + f, err := svr.fs.OpenFile2(svr.ctx, p.Path, osFlags, 0644) + if err != nil { + return statusFromError(p.ID, err) + } + + handle := svr.nextHandle(f) + return &sshFxpHandlePacket{ID: p.ID, Handle: handle} +} + +func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket { + log.Println("pkt: readdir: ", p.Handle) + f, ok := svr.getHandle(p.Handle) + if !ok { + return statusFromError(p.ID, EBADF) + } + + dirents, err := f.Readdir(128) + if err != nil { + return statusFromError(p.ID, err) + } + + idLookup := osIDLookup{} + + ret := &sshFxpNamePacket{ID: p.ID} + for _, dirent := range dirents { + ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{ + Name: dirent.Name(), + LongName: runLs(idLookup, dirent), + Attrs: []interface{}{dirent}, + }) + } + return ret +} + +func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { + log.Println("pkt: setstat: ", p.Path) + // additional unmarshalling is required for each possibility here + b := p.Attrs.([]byte) + var err error + + p.Path = toLocalPath(p.Path) + + debug("setstat name \"%s\"", p.Path) + if (p.Flags & sshFileXferAttrSize) != 0 { + var size uint64 + if size, b, err = unmarshalUint64Safe(b); err == nil { + err = os.Truncate(p.Path, int64(size)) + } + } + if (p.Flags & sshFileXferAttrPermissions) != 0 { + var mode uint32 + if mode, b, err = unmarshalUint32Safe(b); err == nil { + err = os.Chmod(p.Path, os.FileMode(mode)) + } + } + if (p.Flags & sshFileXferAttrACmodTime) != 0 { + var atime uint32 + var mtime uint32 + if atime, b, err = unmarshalUint32Safe(b); err != nil { + } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { + } else { + atimeT := time.Unix(int64(atime), 0) + mtimeT := time.Unix(int64(mtime), 0) + err = os.Chtimes(p.Path, atimeT, mtimeT) + } + } + if (p.Flags & sshFileXferAttrUIDGID) != 0 { + var uid uint32 + var gid uint32 + if uid, b, err = unmarshalUint32Safe(b); err != nil { + } else if gid, _, err = unmarshalUint32Safe(b); err != nil { + } else { + err = os.Chown(p.Path, int(uid), int(gid)) + } + } + + return statusFromError(p.ID, err) +} + +func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { + log.Println("pkt: fsetstat: ", p.Handle) + f, ok := svr.getHandle(p.Handle) + if !ok { + return statusFromError(p.ID, EBADF) + } + + // additional unmarshalling is required for each possibility here + //b := p.Attrs.([]byte) + var err error + + debug("fsetstat name \"%s\"", f.Path.Path) + if (p.Flags & sshFileXferAttrSize) != 0 { + /*var size uint64 + if size, b, err = unmarshalUint64Safe(b); err == nil { + err = f.Truncate(int64(size)) + }*/ + log.Println("WARN: changing size of the file is not supported") + } + if (p.Flags & sshFileXferAttrPermissions) != 0 { + /*var mode uint32 + if mode, b, err = unmarshalUint32Safe(b); err == nil { + err = f.Chmod(os.FileMode(mode)) + }*/ + log.Println("WARN: chmod not supported") + } + if (p.Flags & sshFileXferAttrACmodTime) != 0 { + /*var atime uint32 + var mtime uint32 + if atime, b, err = unmarshalUint32Safe(b); err != nil { + } else if mtime, b, err = unmarshalUint32Safe(b); err != nil { + } else { + atimeT := time.Unix(int64(atime), 0) + mtimeT := time.Unix(int64(mtime), 0) + err = os.Chtimes(f.Name(), atimeT, mtimeT) + }*/ + log.Println("WARN: chtimes not supported") + } + if (p.Flags & sshFileXferAttrUIDGID) != 0 { + /*var uid uint32 + var gid uint32 + if uid, b, err = unmarshalUint32Safe(b); err != nil { + } else if gid, _, err = unmarshalUint32Safe(b); err != nil { + } else { + err = f.Chown(int(uid), int(gid)) + }*/ + log.Println("WARN: chown not supported") + } + + return statusFromError(p.ID, err) +} + +func statusFromError(id uint32, err error) *sshFxpStatusPacket { + ret := &sshFxpStatusPacket{ + ID: id, + StatusError: StatusError{ + // sshFXOk = 0 + // sshFXEOF = 1 + // sshFXNoSuchFile = 2 ENOENT + // sshFXPermissionDenied = 3 + // sshFXFailure = 4 + // sshFXBadMessage = 5 + // sshFXNoConnection = 6 + // sshFXConnectionLost = 7 + // sshFXOPUnsupported = 8 + Code: sshFxOk, + }, + } + if err == nil { + return ret + } + + debug("statusFromError: error is %T %#v", err, err) + ret.StatusError.Code = sshFxFailure + ret.StatusError.msg = err.Error() + + if os.IsNotExist(err) { + ret.StatusError.Code = sshFxNoSuchFile + return ret + } + if code, ok := translateSyscallError(err); ok { + ret.StatusError.Code = code + return ret + } + + switch e := err.(type) { + case fxerr: + ret.StatusError.Code = uint32(e) + default: + if e == io.EOF { + ret.StatusError.Code = sshFxEOF + } + } + + return ret +} diff --git a/sftp/server_statvfs_darwin.go b/sftp/server_statvfs_darwin.go new file mode 100644 index 0000000..8c01dac --- /dev/null +++ b/sftp/server_statvfs_darwin.go @@ -0,0 +1,21 @@ +package sftp + +import ( + "syscall" +) + +func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) { + return &StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Bsize), // fragment size is a linux thing; use block size here + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Fsid: uint64(uint64(stat.Fsid.Val[1])<<32 | uint64(stat.Fsid.Val[0])), // endianness? + Flag: uint64(stat.Flags), // assuming POSIX? + Namemax: 1024, // man 2 statfs shows: #define MAXPATHLEN 1024 + }, nil +} diff --git a/sftp/server_statvfs_impl.go b/sftp/server_statvfs_impl.go new file mode 100644 index 0000000..94b6d83 --- /dev/null +++ b/sftp/server_statvfs_impl.go @@ -0,0 +1,29 @@ +// +build darwin linux + +// fill in statvfs structure with OS specific values +// Statfs_t is different per-kernel, and only exists on some unixes (not Solaris for instance) + +package sftp + +import ( + "syscall" +) + +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + retPkt, err := getStatVFSForPath(p.Path) + if err != nil { + return statusFromError(p.ID, err) + } + retPkt.ID = p.ID + + return retPkt +} + +func getStatVFSForPath(name string) (*StatVFS, error) { + var stat syscall.Statfs_t + if err := syscall.Statfs(name, &stat); err != nil { + return nil, err + } + + return statvfsFromStatfst(&stat) +} diff --git a/sftp/server_statvfs_linux.go b/sftp/server_statvfs_linux.go new file mode 100644 index 0000000..1d180d4 --- /dev/null +++ b/sftp/server_statvfs_linux.go @@ -0,0 +1,22 @@ +// +build linux + +package sftp + +import ( + "syscall" +) + +func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) { + return &StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Frsize), + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Flag: uint64(stat.Flags), // assuming POSIX? + Namemax: uint64(stat.Namelen), + }, nil +} diff --git a/sftp/server_statvfs_plan9.go b/sftp/server_statvfs_plan9.go new file mode 100644 index 0000000..e71a27d --- /dev/null +++ b/sftp/server_statvfs_plan9.go @@ -0,0 +1,13 @@ +package sftp + +import ( + "syscall" +) + +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + return statusFromError(p.ID, syscall.EPLAN9) +} + +func getStatVFSForPath(name string) (*StatVFS, error) { + return nil, syscall.EPLAN9 +} diff --git a/sftp/server_statvfs_stubs.go b/sftp/server_statvfs_stubs.go new file mode 100644 index 0000000..fbf4906 --- /dev/null +++ b/sftp/server_statvfs_stubs.go @@ -0,0 +1,15 @@ +// +build !darwin,!linux,!plan9 + +package sftp + +import ( + "syscall" +) + +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + return statusFromError(p.ID, syscall.ENOTSUP) +} + +func getStatVFSForPath(name string) (*StatVFS, error) { + return nil, syscall.ENOTSUP +} diff --git a/sftp/sftp.go b/sftp/sftp.go new file mode 100644 index 0000000..9a63c39 --- /dev/null +++ b/sftp/sftp.go @@ -0,0 +1,258 @@ +// Package sftp implements the SSH File Transfer Protocol as described in +// https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 +package sftp + +import ( + "fmt" +) + +const ( + sshFxpInit = 1 + sshFxpVersion = 2 + sshFxpOpen = 3 + sshFxpClose = 4 + sshFxpRead = 5 + sshFxpWrite = 6 + sshFxpLstat = 7 + sshFxpFstat = 8 + sshFxpSetstat = 9 + sshFxpFsetstat = 10 + sshFxpOpendir = 11 + sshFxpReaddir = 12 + sshFxpRemove = 13 + sshFxpMkdir = 14 + sshFxpRmdir = 15 + sshFxpRealpath = 16 + sshFxpStat = 17 + sshFxpRename = 18 + sshFxpReadlink = 19 + sshFxpSymlink = 20 + sshFxpStatus = 101 + sshFxpHandle = 102 + sshFxpData = 103 + sshFxpName = 104 + sshFxpAttrs = 105 + sshFxpExtended = 200 + sshFxpExtendedReply = 201 +) + +const ( + sshFxOk = 0 + sshFxEOF = 1 + sshFxNoSuchFile = 2 + sshFxPermissionDenied = 3 + sshFxFailure = 4 + sshFxBadMessage = 5 + sshFxNoConnection = 6 + sshFxConnectionLost = 7 + sshFxOPUnsupported = 8 + + // see draft-ietf-secsh-filexfer-13 + // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1 + sshFxInvalidHandle = 9 + sshFxNoSuchPath = 10 + sshFxFileAlreadyExists = 11 + sshFxWriteProtect = 12 + sshFxNoMedia = 13 + sshFxNoSpaceOnFilesystem = 14 + sshFxQuotaExceeded = 15 + sshFxUnknownPrincipal = 16 + sshFxLockConflict = 17 + sshFxDirNotEmpty = 18 + sshFxNotADirectory = 19 + sshFxInvalidFilename = 20 + sshFxLinkLoop = 21 + sshFxCannotDelete = 22 + sshFxInvalidParameter = 23 + sshFxFileIsADirectory = 24 + sshFxByteRangeLockConflict = 25 + sshFxByteRangeLockRefused = 26 + sshFxDeletePending = 27 + sshFxFileCorrupt = 28 + sshFxOwnerInvalid = 29 + sshFxGroupInvalid = 30 + sshFxNoMatchingByteRangeLock = 31 +) + +const ( + sshFxfRead = 0x00000001 + sshFxfWrite = 0x00000002 + sshFxfAppend = 0x00000004 + sshFxfCreat = 0x00000008 + sshFxfTrunc = 0x00000010 + sshFxfExcl = 0x00000020 +) + +var ( + // supportedSFTPExtensions defines the supported extensions + supportedSFTPExtensions = []sshExtensionPair{ + {"hardlink@openssh.com", "1"}, + {"posix-rename@openssh.com", "1"}, + {"statvfs@openssh.com", "2"}, + } + sftpExtensions = supportedSFTPExtensions +) + +type fxp uint8 + +func (f fxp) String() string { + switch f { + case sshFxpInit: + return "SSH_FXP_INIT" + case sshFxpVersion: + return "SSH_FXP_VERSION" + case sshFxpOpen: + return "SSH_FXP_OPEN" + case sshFxpClose: + return "SSH_FXP_CLOSE" + case sshFxpRead: + return "SSH_FXP_READ" + case sshFxpWrite: + return "SSH_FXP_WRITE" + case sshFxpLstat: + return "SSH_FXP_LSTAT" + case sshFxpFstat: + return "SSH_FXP_FSTAT" + case sshFxpSetstat: + return "SSH_FXP_SETSTAT" + case sshFxpFsetstat: + return "SSH_FXP_FSETSTAT" + case sshFxpOpendir: + return "SSH_FXP_OPENDIR" + case sshFxpReaddir: + return "SSH_FXP_READDIR" + case sshFxpRemove: + return "SSH_FXP_REMOVE" + case sshFxpMkdir: + return "SSH_FXP_MKDIR" + case sshFxpRmdir: + return "SSH_FXP_RMDIR" + case sshFxpRealpath: + return "SSH_FXP_REALPATH" + case sshFxpStat: + return "SSH_FXP_STAT" + case sshFxpRename: + return "SSH_FXP_RENAME" + case sshFxpReadlink: + return "SSH_FXP_READLINK" + case sshFxpSymlink: + return "SSH_FXP_SYMLINK" + case sshFxpStatus: + return "SSH_FXP_STATUS" + case sshFxpHandle: + return "SSH_FXP_HANDLE" + case sshFxpData: + return "SSH_FXP_DATA" + case sshFxpName: + return "SSH_FXP_NAME" + case sshFxpAttrs: + return "SSH_FXP_ATTRS" + case sshFxpExtended: + return "SSH_FXP_EXTENDED" + case sshFxpExtendedReply: + return "SSH_FXP_EXTENDED_REPLY" + default: + return "unknown" + } +} + +type fx uint8 + +func (f fx) String() string { + switch f { + case sshFxOk: + return "SSH_FX_OK" + case sshFxEOF: + return "SSH_FX_EOF" + case sshFxNoSuchFile: + return "SSH_FX_NO_SUCH_FILE" + case sshFxPermissionDenied: + return "SSH_FX_PERMISSION_DENIED" + case sshFxFailure: + return "SSH_FX_FAILURE" + case sshFxBadMessage: + return "SSH_FX_BAD_MESSAGE" + case sshFxNoConnection: + return "SSH_FX_NO_CONNECTION" + case sshFxConnectionLost: + return "SSH_FX_CONNECTION_LOST" + case sshFxOPUnsupported: + return "SSH_FX_OP_UNSUPPORTED" + default: + return "unknown" + } +} + +type unexpectedPacketErr struct { + want, got uint8 +} + +func (u *unexpectedPacketErr) Error() string { + return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got)) +} + +func unimplementedPacketErr(u uint8) error { + return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u)) +} + +type unexpectedIDErr struct{ want, got uint32 } + +func (u *unexpectedIDErr) Error() string { + return fmt.Sprintf("sftp: unexpected id: want %d, got %d", u.want, u.got) +} + +func unimplementedSeekWhence(whence int) error { + return fmt.Errorf("sftp: unimplemented seek whence %d", whence) +} + +func unexpectedCount(want, got uint32) error { + return fmt.Errorf("sftp: unexpected count: want %d, got %d", want, got) +} + +type unexpectedVersionErr struct{ want, got uint32 } + +func (u *unexpectedVersionErr) Error() string { + return fmt.Sprintf("sftp: unexpected server version: want %v, got %v", u.want, u.got) +} + +// A StatusError is returned when an SFTP operation fails, and provides +// additional information about the failure. +type StatusError struct { + Code uint32 + msg, lang string +} + +func (s *StatusError) Error() string { + return fmt.Sprintf("sftp: %q (%v)", s.msg, fx(s.Code)) +} + +// FxCode returns the error code typed to match against the exported codes +func (s *StatusError) FxCode() fxerr { + return fxerr(s.Code) +} + +func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error) { + for _, supportedExtension := range supportedSFTPExtensions { + if supportedExtension.Name == extensionName { + return supportedExtension, nil + } + } + return sshExtensionPair{}, fmt.Errorf("unsupported extension: %s", extensionName) +} + +// SetSFTPExtensions allows to customize the supported server extensions. +// See the variable supportedSFTPExtensions for supported extensions. +// This method accepts a slice of sshExtensionPair names for example 'hardlink@openssh.com'. +// If an invalid extension is given an error will be returned and nothing will be changed +func SetSFTPExtensions(extensions ...string) error { + tempExtensions := []sshExtensionPair{} + for _, extension := range extensions { + sftpExtension, err := getSupportedExtensionByName(extension) + if err != nil { + return err + } + tempExtensions = append(tempExtensions, sftpExtension) + } + sftpExtensions = tempExtensions + return nil +} diff --git a/sftp/stat_plan9.go b/sftp/stat_plan9.go new file mode 100644 index 0000000..761abdf --- /dev/null +++ b/sftp/stat_plan9.go @@ -0,0 +1,103 @@ +package sftp + +import ( + "os" + "syscall" +) + +var EBADF = syscall.NewError("fd out of range or not open") + +func wrapPathError(filepath string, err error) error { + if errno, ok := err.(syscall.ErrorString); ok { + return &os.PathError{Path: filepath, Err: errno} + } + return err +} + +// translateErrno translates a syscall error number to a SFTP error code. +func translateErrno(errno syscall.ErrorString) uint32 { + switch errno { + case "": + return sshFxOk + case syscall.ENOENT: + return sshFxNoSuchFile + case syscall.EPERM: + return sshFxPermissionDenied + } + + return sshFxFailure +} + +func translateSyscallError(err error) (uint32, bool) { + switch e := err.(type) { + case syscall.ErrorString: + return translateErrno(e), true + case *os.PathError: + debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err) + if errno, ok := e.Err.(syscall.ErrorString); ok { + return translateErrno(errno), true + } + } + return 0, false +} + +// isRegular returns true if the mode describes a regular file. +func isRegular(mode uint32) bool { + return mode&S_IFMT == syscall.S_IFREG +} + +// toFileMode converts sftp filemode bits to the os.FileMode specification +func toFileMode(mode uint32) os.FileMode { + var fm = os.FileMode(mode & 0777) + + switch mode & S_IFMT { + case syscall.S_IFBLK: + fm |= os.ModeDevice + case syscall.S_IFCHR: + fm |= os.ModeDevice | os.ModeCharDevice + case syscall.S_IFDIR: + fm |= os.ModeDir + case syscall.S_IFIFO: + fm |= os.ModeNamedPipe + case syscall.S_IFLNK: + fm |= os.ModeSymlink + case syscall.S_IFREG: + // nothing to do + case syscall.S_IFSOCK: + fm |= os.ModeSocket + } + + return fm +} + +// fromFileMode converts from the os.FileMode specification to sftp filemode bits +func fromFileMode(mode os.FileMode) uint32 { + ret := uint32(mode & os.ModePerm) + + switch mode & os.ModeType { + case os.ModeDevice | os.ModeCharDevice: + ret |= syscall.S_IFCHR + case os.ModeDevice: + ret |= syscall.S_IFBLK + case os.ModeDir: + ret |= syscall.S_IFDIR + case os.ModeNamedPipe: + ret |= syscall.S_IFIFO + case os.ModeSymlink: + ret |= syscall.S_IFLNK + case 0: + ret |= syscall.S_IFREG + case os.ModeSocket: + ret |= syscall.S_IFSOCK + } + + return ret +} + +// Plan 9 doesn't have setuid, setgid or sticky, but a Plan 9 client should +// be able to send these bits to a POSIX server. +const ( + s_ISUID = 04000 + s_ISGID = 02000 + s_ISVTX = 01000 +) diff --git a/sftp/stat_posix.go b/sftp/stat_posix.go new file mode 100644 index 0000000..5b870e2 --- /dev/null +++ b/sftp/stat_posix.go @@ -0,0 +1,124 @@ +//go:build !plan9 +// +build !plan9 + +package sftp + +import ( + "os" + "syscall" +) + +const EBADF = syscall.EBADF + +func wrapPathError(filepath string, err error) error { + if errno, ok := err.(syscall.Errno); ok { + return &os.PathError{Path: filepath, Err: errno} + } + return err +} + +// translateErrno translates a syscall error number to a SFTP error code. +func translateErrno(errno syscall.Errno) uint32 { + switch errno { + case 0: + return sshFxOk + case syscall.ENOENT: + return sshFxNoSuchFile + case syscall.EACCES, syscall.EPERM: + return sshFxPermissionDenied + } + + return sshFxFailure +} + +func translateSyscallError(err error) (uint32, bool) { + switch e := err.(type) { + case syscall.Errno: + return translateErrno(e), true + case *os.PathError: + debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err) + if errno, ok := e.Err.(syscall.Errno); ok { + return translateErrno(errno), true + } + } + return 0, false +} + +// isRegular returns true if the mode describes a regular file. +func isRegular(mode uint32) bool { + return mode&S_IFMT == syscall.S_IFREG +} + +// toFileMode converts sftp filemode bits to the os.FileMode specification +func toFileMode(mode uint32) os.FileMode { + var fm = os.FileMode(mode & 0777) + + switch mode & S_IFMT { + case syscall.S_IFBLK: + fm |= os.ModeDevice + case syscall.S_IFCHR: + fm |= os.ModeDevice | os.ModeCharDevice + case syscall.S_IFDIR: + fm |= os.ModeDir + case syscall.S_IFIFO: + fm |= os.ModeNamedPipe + case syscall.S_IFLNK: + fm |= os.ModeSymlink + case syscall.S_IFREG: + // nothing to do + case syscall.S_IFSOCK: + fm |= os.ModeSocket + } + + if mode&syscall.S_ISUID != 0 { + fm |= os.ModeSetuid + } + if mode&syscall.S_ISGID != 0 { + fm |= os.ModeSetgid + } + if mode&syscall.S_ISVTX != 0 { + fm |= os.ModeSticky + } + + return fm +} + +// fromFileMode converts from the os.FileMode specification to sftp filemode bits +func fromFileMode(mode os.FileMode) uint32 { + ret := uint32(mode & os.ModePerm) + + switch mode & os.ModeType { + case os.ModeDevice | os.ModeCharDevice: + ret |= syscall.S_IFCHR + case os.ModeDevice: + ret |= syscall.S_IFBLK + case os.ModeDir: + ret |= syscall.S_IFDIR + case os.ModeNamedPipe: + ret |= syscall.S_IFIFO + case os.ModeSymlink: + ret |= syscall.S_IFLNK + case 0: + ret |= syscall.S_IFREG + case os.ModeSocket: + ret |= syscall.S_IFSOCK + } + + if mode&os.ModeSetuid != 0 { + ret |= syscall.S_ISUID + } + if mode&os.ModeSetgid != 0 { + ret |= syscall.S_ISGID + } + if mode&os.ModeSticky != 0 { + ret |= syscall.S_ISVTX + } + + return ret +} + +const ( + s_ISUID = syscall.S_ISUID + s_ISGID = syscall.S_ISGID + s_ISVTX = syscall.S_ISVTX +) diff --git a/sftp/syscall_fixed.go b/sftp/syscall_fixed.go new file mode 100644 index 0000000..d404577 --- /dev/null +++ b/sftp/syscall_fixed.go @@ -0,0 +1,9 @@ +// +build plan9 windows js,wasm + +// Go defines S_IFMT on windows, plan9 and js/wasm as 0x1f000 instead of +// 0xf000. None of the the other S_IFxyz values include the "1" (in 0x1f000) +// which prevents them from matching the bitmask. + +package sftp + +const S_IFMT = 0xf000 diff --git a/sftp/syscall_good.go b/sftp/syscall_good.go new file mode 100644 index 0000000..4c2b240 --- /dev/null +++ b/sftp/syscall_good.go @@ -0,0 +1,8 @@ +// +build !plan9,!windows +// +build !js !wasm + +package sftp + +import "syscall" + +const S_IFMT = syscall.S_IFMT -- cgit v1.2.3