aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQuentin Dufour <quentin@deuxfleurs.fr>2021-11-19 19:54:49 +0100
committerQuentin Dufour <quentin@deuxfleurs.fr>2021-11-19 19:54:49 +0100
commit0ee29e31ddcc81f541de7459b0a5e40dfa552672 (patch)
tree859ff133f8c78bd034b0c2184cdad0ce9f38b065
parent93631b4e3d5195d446504db1c4a2bc7468b3ef28 (diff)
downloadbagage-0ee29e31ddcc81f541de7459b0a5e40dfa552672.tar.gz
bagage-0ee29e31ddcc81f541de7459b0a5e40dfa552672.zip
Working on SFTP
-rw-r--r--.gitignore3
-rw-r--r--auth_ldap.go84
-rw-r--r--config.go1
-rw-r--r--go.mod3
-rw-r--r--go.sum18
-rw-r--r--internal/encoding/ssh/filexfer/attrs.go326
-rw-r--r--internal/encoding/ssh/filexfer/attrs_test.go231
-rw-r--r--internal/encoding/ssh/filexfer/buffer.go293
-rw-r--r--internal/encoding/ssh/filexfer/extended_packets.go142
-rw-r--r--internal/encoding/ssh/filexfer/extended_packets_test.go240
-rw-r--r--internal/encoding/ssh/filexfer/extensions.go46
-rw-r--r--internal/encoding/ssh/filexfer/extensions_test.go49
-rw-r--r--internal/encoding/ssh/filexfer/filexfer.go54
-rw-r--r--internal/encoding/ssh/filexfer/fx.go147
-rw-r--r--internal/encoding/ssh/filexfer/fx_test.go102
-rw-r--r--internal/encoding/ssh/filexfer/fxp.go124
-rw-r--r--internal/encoding/ssh/filexfer/fxp_test.go84
-rw-r--r--internal/encoding/ssh/filexfer/handle_packets.go249
-rw-r--r--internal/encoding/ssh/filexfer/handle_packets_test.go282
-rw-r--r--internal/encoding/ssh/filexfer/init_packets.go99
-rw-r--r--internal/encoding/ssh/filexfer/init_packets_test.go114
-rw-r--r--internal/encoding/ssh/filexfer/open_packets.go89
-rw-r--r--internal/encoding/ssh/filexfer/open_packets_test.go107
-rw-r--r--internal/encoding/ssh/filexfer/openssh/fsync.go73
-rw-r--r--internal/encoding/ssh/filexfer/openssh/fsync_test.go62
-rw-r--r--internal/encoding/ssh/filexfer/openssh/hardlink.go79
-rw-r--r--internal/encoding/ssh/filexfer/openssh/hardlink_test.go69
-rw-r--r--internal/encoding/ssh/filexfer/openssh/openssh.go2
-rw-r--r--internal/encoding/ssh/filexfer/openssh/posix-rename.go79
-rw-r--r--internal/encoding/ssh/filexfer/openssh/posix-rename_test.go69
-rw-r--r--internal/encoding/ssh/filexfer/openssh/statvfs.go256
-rw-r--r--internal/encoding/ssh/filexfer/openssh/statvfs_test.go239
-rw-r--r--internal/encoding/ssh/filexfer/packets.go323
-rw-r--r--internal/encoding/ssh/filexfer/packets_test.go132
-rw-r--r--internal/encoding/ssh/filexfer/path_packets.go368
-rw-r--r--internal/encoding/ssh/filexfer/path_packets_test.go450
-rw-r--r--internal/encoding/ssh/filexfer/permissions.go114
-rw-r--r--internal/encoding/ssh/filexfer/response_packets.go243
-rw-r--r--internal/encoding/ssh/filexfer/response_packets_test.go296
-rw-r--r--main.go157
-rw-r--r--s3/file.go (renamed from s3_file.go)86
-rw-r--r--s3/fs.go (renamed from s3_fs.go)38
-rw-r--r--s3/path.go68
-rw-r--r--s3/stat.go (renamed from s3_stat.go)28
-rw-r--r--s3_path.go57
-rw-r--r--sftp/allocator.go101
-rw-r--r--sftp/attrs.go90
-rw-r--r--sftp/attrs_stubs.go11
-rw-r--r--sftp/attrs_unix.go16
-rw-r--r--sftp/client.go1936
-rw-r--r--sftp/conn.go189
-rw-r--r--sftp/debug.go9
-rw-r--r--sftp/ls_formatting.go81
-rw-r--r--sftp/ls_plan9.go21
-rw-r--r--sftp/ls_stub.go11
-rw-r--r--sftp/ls_unix.go23
-rw-r--r--sftp/packet.go1276
-rw-r--r--sftp/packet_manager.go221
-rw-r--r--sftp/packet_typing.go135
-rw-r--r--sftp/pool.go79
-rw-r--r--sftp/release.go5
-rw-r--r--sftp/request-attrs.go63
-rw-r--r--sftp/request-errors.go54
-rw-r--r--sftp/request-interfaces.go121
-rw-r--r--sftp/request-plan9.go34
-rw-r--r--sftp/request-server.go304
-rw-r--r--sftp/request-unix.go27
-rw-r--r--sftp/request.go628
-rw-r--r--sftp/request_windows.go44
-rw-r--r--sftp/server.go643
-rw-r--r--sftp/server_statvfs_darwin.go21
-rw-r--r--sftp/server_statvfs_impl.go29
-rw-r--r--sftp/server_statvfs_linux.go22
-rw-r--r--sftp/server_statvfs_plan9.go13
-rw-r--r--sftp/server_statvfs_stubs.go15
-rw-r--r--sftp/sftp.go258
-rw-r--r--sftp/stat_plan9.go103
-rw-r--r--sftp/stat_posix.go124
-rw-r--r--sftp/syscall_fixed.go9
-rw-r--r--sftp/syscall_good.go8
-rw-r--r--webdav.go3
81 files changed, 12748 insertions, 154 deletions
diff --git a/.gitignore b/.gitignore
index 76cf493..dacbb71 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,5 @@
bagage
.env
+*.swp
+id_rsa
+id_rsa.pub
diff --git a/auth_ldap.go b/auth_ldap.go
index bf2a9fb..26d3565 100644
--- a/auth_ldap.go
+++ b/auth_ldap.go
@@ -18,41 +18,17 @@ type LdapPreAuth struct {
func (l LdapPreAuth) WithCreds(username, password string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var e *LdapWrongPasswordError
- // 1. Connect to the server
- conn, err := ldapConnect(l.WithConfig)
- if err != nil {
- l.OnFailure.WithError(err).ServeHTTP(w, r)
- return
- }
- defer conn.Close()
-
- // 2. Authenticate with provided credentials
- // @FIXME we should better check the error, it could also be due to an LDAP error
- err = conn.auth(username, password)
- if err != nil {
+ access_key, secret_key, err := LdapGetS3(l.WithConfig, username, password)
+
+ if err == nil {
+ l.OnCreds.WithCreds(access_key, secret_key).ServeHTTP(w, r)
+ } else if errors.As(err, &e) {
l.OnWrongPassword.WithError(err).ServeHTTP(w, r)
- return
- }
-
- // 3. Fetch user's profile
- profile, err := conn.profile()
- if err != nil {
- l.OnFailure.WithError(err).ServeHTTP(w, r)
- return
- }
-
- // 4. Basic checks upon users' attributes
- access_key := profile.GetAttributeValue("garage_s3_access_key")
- secret_key := profile.GetAttributeValue("garage_s3_secret_key")
- if access_key == "" || secret_key == "" {
- err = errors.New(fmt.Sprintf("Either access key or secret key is missing in LDAP for %s", conn.userDn))
- l.OnFailure.WithError(err).ServeHTTP(w, r)
- return
+ } else {
+ l.OnFailure.WithError(e).ServeHTTP(w, r)
}
-
- // 5. Send fetched credentials to the next middleware
- l.OnCreds.WithCreds(access_key, secret_key).ServeHTTP(w, r)
})
}
@@ -66,6 +42,50 @@ type ldapConnector struct {
userDn string
}
+type LdapError struct {
+ Username string
+ Err error
+}
+func (e *LdapError) Error() string { return "ldap error for "+e.Username+": "+e.Err.Error() }
+type LdapWrongPasswordError struct { LdapError }
+
+func LdapGetS3(c *Config, username, password string) (access_key, secret_key string, werr error) {
+ // 1. Connect to the server
+ conn, err := ldapConnect(c)
+ if err != nil {
+ werr = &LdapError { username, err }
+ return
+ }
+ defer conn.Close()
+
+ // 2. Authenticate with provided credentials
+ // @FIXME we should better check the error, it could also be due to an LDAP error
+ err = conn.auth(username, password)
+ if err != nil {
+ werr = &LdapWrongPasswordError { LdapError { username, err } }
+ return
+ }
+
+ // 3. Fetch user's profile
+ profile, err := conn.profile()
+ if err != nil {
+ werr = &LdapError { username, err }
+ return
+ }
+
+ // 4. Basic checks upon users' attributes
+ access_key = profile.GetAttributeValue("garage_s3_access_key")
+ secret_key = profile.GetAttributeValue("garage_s3_secret_key")
+ if access_key == "" || secret_key == "" {
+ err = errors.New(fmt.Sprintf("Either access key or secret key is missing in LDAP for %s", conn.userDn))
+ werr = &LdapError { username, err }
+ return
+ }
+
+ // 5. Send fetched credentials to the next middleware
+ return
+}
+
func ldapConnect(c *Config) (ldapConnector, error) {
ldapSock, err := ldap.Dial("tcp", c.LdapServer)
if err != nil {
diff --git a/config.go b/config.go
index 002f317..b660e68 100644
--- a/config.go
+++ b/config.go
@@ -14,6 +14,7 @@ type Config struct {
UserNameAttr string `env:"BAGAGE_LDAP_USERNAME_ATTR" default:"cn"`
Endpoint string `env:"BAGAGE_S3_ENDPOINT" default:"garage.deuxfleurs.fr"`
UseSSL bool `env:"BAGAGE_S3_SSL" default:"true"`
+ SSHKey string `env:"BAGAGE_SSH_KEY" default:"id_rsa"`
}
func (c *Config) LoadWithDefault() *Config {
diff --git a/go.mod b/go.mod
index 546f5ef..7d2b49b 100644
--- a/go.mod
+++ b/go.mod
@@ -4,6 +4,9 @@ go 1.16
require (
github.com/go-ldap/ldap/v3 v3.4.1
+ github.com/kr/fs v0.1.0
github.com/minio/minio-go/v7 v7.0.12
+ github.com/pkg/sftp v1.13.4
+ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b
golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d
)
diff --git a/go.sum b/go.sum
index 57620ee..bdc858d 100644
--- a/go.sum
+++ b/go.sum
@@ -21,6 +21,8 @@ github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfV
github.com/klauspost/cpuid v1.2.3/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s=
github.com/klauspost/cpuid v1.3.1/go.mod h1:bYW4mA6ZgKPob1/Dlai2LviZJO7KGI3uoWLd42rAQw4=
+github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
+github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
@@ -38,6 +40,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/pkg/sftp v1.13.4 h1:Lb0RYJCmgUcBgZosfoi9Y9sbl6+LJgOIgk/2Y4YjMFg=
+github.com/pkg/sftp v1.13.4/go.mod h1:LzqnAvaD5TWeNBsZpfKxSYn1MbjWwOsCIAFFJbpIsK8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc=
@@ -51,16 +55,19 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f h1:aZp0e2vLN4MToVqnjNEYEtrEA8RH8U8FN1CU7JgqsPU=
golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
+golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b h1:7mWr3k41Qtv8XlltBkDkl8LoP3mpSgBW8BUoxtEdbXg=
+golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
+golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d h1:LO7XpTYMwTqxjLcGWPijK3vRXg1aWdlNOVOHRq45d7c=
golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -69,9 +76,11 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 h1:iGu644GcxtEcrInvDsQRCwJjtCIOlT2V7IRt6ah2Whw=
+golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -84,5 +93,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/ini.v1 v1.57.0 h1:9unxIsFcTt4I55uWluz+UmL95q4kdJ0buvQ1ZIqVQww=
gopkg.in/ini.v1 v1.57.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/encoding/ssh/filexfer/attrs.go b/internal/encoding/ssh/filexfer/attrs.go
new file mode 100644
index 0000000..1d4bb79
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/attrs.go
@@ -0,0 +1,326 @@
+package filexfer
+
+// Attributes related flags.
+const (
+ AttrSize = 1 << iota // SSH_FILEXFER_ATTR_SIZE
+ AttrUIDGID // SSH_FILEXFER_ATTR_UIDGID
+ AttrPermissions // SSH_FILEXFER_ATTR_PERMISSIONS
+ AttrACModTime // SSH_FILEXFER_ACMODTIME
+
+ AttrExtended = 1 << 31 // SSH_FILEXFER_ATTR_EXTENDED
+)
+
+// Attributes defines the file attributes type defined in draft-ietf-secsh-filexfer-02
+//
+// Defined in: https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-5
+type Attributes struct {
+ Flags uint32
+
+ // AttrSize
+ Size uint64
+
+ // AttrUIDGID
+ UID uint32
+ GID uint32
+
+ // AttrPermissions
+ Permissions FileMode
+
+ // AttrACmodTime
+ ATime uint32
+ MTime uint32
+
+ // AttrExtended
+ ExtendedAttributes []ExtendedAttribute
+}
+
+// GetSize returns the Size field and a bool that is true if and only if the value is valid/defined.
+func (a *Attributes) GetSize() (size uint64, ok bool) {
+ return a.Size, a.Flags&AttrSize != 0
+}
+
+// SetSize is a convenience function that sets the Size field,
+// and marks the field as valid/defined in Flags.
+func (a *Attributes) SetSize(size uint64) {
+ a.Flags |= AttrSize
+ a.Size = size
+}
+
+// GetUIDGID returns the UID and GID fields and a bool that is true if and only if the values are valid/defined.
+func (a *Attributes) GetUIDGID() (uid, gid uint32, ok bool) {
+ return a.UID, a.GID, a.Flags&AttrUIDGID != 0
+}
+
+// SetUIDGID is a convenience function that sets the UID and GID fields,
+// and marks the fields as valid/defined in Flags.
+func (a *Attributes) SetUIDGID(uid, gid uint32) {
+ a.Flags |= AttrUIDGID
+ a.UID = uid
+ a.GID = gid
+}
+
+// GetPermissions returns the Permissions field and a bool that is true if and only if the value is valid/defined.
+func (a *Attributes) GetPermissions() (perms FileMode, ok bool) {
+ return a.Permissions, a.Flags&AttrPermissions != 0
+}
+
+// SetPermissions is a convenience function that sets the Permissions field,
+// and marks the field as valid/defined in Flags.
+func (a *Attributes) SetPermissions(perms FileMode) {
+ a.Flags |= AttrPermissions
+ a.Permissions = perms
+}
+
+// GetACModTime returns the ATime and MTime fields and a bool that is true if and only if the values are valid/defined.
+func (a *Attributes) GetACModTime() (atime, mtime uint32, ok bool) {
+ return a.ATime, a.MTime, a.Flags&AttrACModTime != 0
+ return a.ATime, a.MTime, a.Flags&AttrACModTime != 0
+}
+
+// SetACModTime is a convenience function that sets the ATime and MTime fields,
+// and marks the fields as valid/defined in Flags.
+func (a *Attributes) SetACModTime(atime, mtime uint32) {
+ a.Flags |= AttrACModTime
+ a.ATime = atime
+ a.MTime = mtime
+}
+
+// Len returns the number of bytes a would marshal into.
+func (a *Attributes) Len() int {
+ length := 4
+
+ if a.Flags&AttrSize != 0 {
+ length += 8
+ }
+
+ if a.Flags&AttrUIDGID != 0 {
+ length += 4 + 4
+ }
+
+ if a.Flags&AttrPermissions != 0 {
+ length += 4
+ }
+
+ if a.Flags&AttrACModTime != 0 {
+ length += 4 + 4
+ }
+
+ if a.Flags&AttrExtended != 0 {
+ length += 4
+
+ for _, ext := range a.ExtendedAttributes {
+ length += ext.Len()
+ }
+ }
+
+ return length
+}
+
+// MarshalInto marshals e onto the end of the given Buffer.
+func (a *Attributes) MarshalInto(b *Buffer) {
+ b.AppendUint32(a.Flags)
+
+ if a.Flags&AttrSize != 0 {
+ b.AppendUint64(a.Size)
+ }
+
+ if a.Flags&AttrUIDGID != 0 {
+ b.AppendUint32(a.UID)
+ b.AppendUint32(a.GID)
+ }
+
+ if a.Flags&AttrPermissions != 0 {
+ b.AppendUint32(uint32(a.Permissions))
+ }
+
+ if a.Flags&AttrACModTime != 0 {
+ b.AppendUint32(a.ATime)
+ b.AppendUint32(a.MTime)
+ }
+
+ if a.Flags&AttrExtended != 0 {
+ b.AppendUint32(uint32(len(a.ExtendedAttributes)))
+
+ for _, ext := range a.ExtendedAttributes {
+ ext.MarshalInto(b)
+ }
+ }
+}
+
+// MarshalBinary returns a as the binary encoding of a.
+func (a *Attributes) MarshalBinary() ([]byte, error) {
+ buf := NewBuffer(make([]byte, 0, a.Len()))
+ a.MarshalInto(buf)
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom unmarshals an Attributes from the given Buffer into e.
+//
+// NOTE: The values of fields not covered in the a.Flags are explicitly undefined.
+func (a *Attributes) UnmarshalFrom(b *Buffer) (err error) {
+ flags, err := b.ConsumeUint32()
+ if err != nil {
+ return err
+ }
+
+ return a.XXX_UnmarshalByFlags(flags, b)
+}
+
+// XXX_UnmarshalByFlags uses the pre-existing a.Flags field to determine which fields to decode.
+// DO NOT USE THIS: it is an anti-corruption function to implement existing internal usage in pkg/sftp.
+// This function is not a part of any compatibility promise.
+func (a *Attributes) XXX_UnmarshalByFlags(flags uint32, b *Buffer) (err error) {
+ a.Flags = flags
+
+ // Short-circuit dummy attributes.
+ if a.Flags == 0 {
+ return nil
+ }
+
+ if a.Flags&AttrSize != 0 {
+ if a.Size, err = b.ConsumeUint64(); err != nil {
+ return err
+ }
+ }
+
+ if a.Flags&AttrUIDGID != 0 {
+ if a.UID, err = b.ConsumeUint32(); err != nil {
+ return err
+ }
+
+ if a.GID, err = b.ConsumeUint32(); err != nil {
+ return err
+ }
+ }
+
+ if a.Flags&AttrPermissions != 0 {
+ m, err := b.ConsumeUint32()
+ if err != nil {
+ return err
+ }
+
+ a.Permissions = FileMode(m)
+ }
+
+ if a.Flags&AttrACModTime != 0 {
+ if a.ATime, err = b.ConsumeUint32(); err != nil {
+ return err
+ }
+
+ if a.MTime, err = b.ConsumeUint32(); err != nil {
+ return err
+ }
+ }
+
+ if a.Flags&AttrExtended != 0 {
+ count, err := b.ConsumeUint32()
+ if err != nil {
+ return err
+ }
+
+ a.ExtendedAttributes = make([]ExtendedAttribute, count)
+ for i := range a.ExtendedAttributes {
+ a.ExtendedAttributes[i].UnmarshalFrom(b)
+ }
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the binary encoding of Attributes into e.
+func (a *Attributes) UnmarshalBinary(data []byte) error {
+ return a.UnmarshalFrom(NewBuffer(data))
+}
+
+// ExtendedAttribute defines the extended file attribute type defined in draft-ietf-secsh-filexfer-02
+//
+// Defined in: https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-5
+type ExtendedAttribute struct {
+ Type string
+ Data string
+}
+
+// Len returns the number of bytes e would marshal into.
+func (e *ExtendedAttribute) Len() int {
+ return 4 + len(e.Type) + 4 + len(e.Data)
+}
+
+// MarshalInto marshals e onto the end of the given Buffer.
+func (e *ExtendedAttribute) MarshalInto(b *Buffer) {
+ b.AppendString(e.Type)
+ b.AppendString(e.Data)
+}
+
+// MarshalBinary returns e as the binary encoding of e.
+func (e *ExtendedAttribute) MarshalBinary() ([]byte, error) {
+ buf := NewBuffer(make([]byte, 0, e.Len()))
+ e.MarshalInto(buf)
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom unmarshals an ExtendedAattribute from the given Buffer into e.
+func (e *ExtendedAttribute) UnmarshalFrom(b *Buffer) (err error) {
+ if e.Type, err = b.ConsumeString(); err != nil {
+ return err
+ }
+
+ if e.Data, err = b.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the binary encoding of ExtendedAttribute into e.
+func (e *ExtendedAttribute) UnmarshalBinary(data []byte) error {
+ return e.UnmarshalFrom(NewBuffer(data))
+}
+
+// NameEntry implements the SSH_FXP_NAME repeated data type from draft-ietf-secsh-filexfer-02
+//
+// This type is incompatible with versions 4 or higher.
+type NameEntry struct {
+ Filename string
+ Longname string
+ Attrs Attributes
+}
+
+// Len returns the number of bytes e would marshal into.
+func (e *NameEntry) Len() int {
+ return 4 + len(e.Filename) + 4 + len(e.Longname) + e.Attrs.Len()
+}
+
+// MarshalInto marshals e onto the end of the given Buffer.
+func (e *NameEntry) MarshalInto(b *Buffer) {
+ b.AppendString(e.Filename)
+ b.AppendString(e.Longname)
+
+ e.Attrs.MarshalInto(b)
+}
+
+// MarshalBinary returns e as the binary encoding of e.
+func (e *NameEntry) MarshalBinary() ([]byte, error) {
+ buf := NewBuffer(make([]byte, 0, e.Len()))
+ e.MarshalInto(buf)
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom unmarshals an NameEntry from the given Buffer into e.
+//
+// NOTE: The values of fields not covered in the a.Flags are explicitly undefined.
+func (e *NameEntry) UnmarshalFrom(b *Buffer) (err error) {
+ if e.Filename, err = b.ConsumeString(); err != nil {
+ return err
+ }
+
+ if e.Longname, err = b.ConsumeString(); err != nil {
+ return err
+ }
+
+ return e.Attrs.UnmarshalFrom(b)
+}
+
+// UnmarshalBinary decodes the binary encoding of NameEntry into e.
+func (e *NameEntry) UnmarshalBinary(data []byte) error {
+ return e.UnmarshalFrom(NewBuffer(data))
+}
diff --git a/internal/encoding/ssh/filexfer/attrs_test.go b/internal/encoding/ssh/filexfer/attrs_test.go
new file mode 100644
index 0000000..c03015c
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/attrs_test.go
@@ -0,0 +1,231 @@
+package filexfer
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestAttributes(t *testing.T) {
+ const (
+ size = 0x123456789ABCDEF0
+ uid = 1000
+ gid = 100
+ perms = 0x87654321
+ atime = 0x2A2B2C2D
+ mtime = 0x42434445
+ )
+
+ extAttr := ExtendedAttribute{
+ Type: "foo",
+ Data: "bar",
+ }
+
+ attr := &Attributes{
+ Size: size,
+ UID: uid,
+ GID: gid,
+ Permissions: perms,
+ ATime: atime,
+ MTime: mtime,
+ ExtendedAttributes: []ExtendedAttribute{
+ extAttr,
+ },
+ }
+
+ type test struct {
+ name string
+ flags uint32
+ encoded []byte
+ }
+
+ tests := []test{
+ {
+ name: "empty",
+ encoded: []byte{
+ 0x00, 0x00, 0x00, 0x00,
+ },
+ },
+ {
+ name: "size",
+ flags: AttrSize,
+ encoded: []byte{
+ 0x00, 0x00, 0x00, 0x01,
+ 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0,
+ },
+ },
+ {
+ name: "uidgid",
+ flags: AttrUIDGID,
+ encoded: []byte{
+ 0x00, 0x00, 0x00, 0x02,
+ 0x00, 0x00, 0x03, 0xE8,
+ 0x00, 0x00, 0x00, 100,
+ },
+ },
+ {
+ name: "permissions",
+ flags: AttrPermissions,
+ encoded: []byte{
+ 0x00, 0x00, 0x00, 0x04,
+ 0x87, 0x65, 0x43, 0x21,
+ },
+ },
+ {
+ name: "acmodtime",
+ flags: AttrACModTime,
+ encoded: []byte{
+ 0x00, 0x00, 0x00, 0x08,
+ 0x2A, 0x2B, 0x2C, 0x2D,
+ 0x42, 0x43, 0x44, 0x45,
+ },
+ },
+ {
+ name: "extended",
+ flags: AttrExtended,
+ encoded: []byte{
+ 0x80, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01,
+ 0x00, 0x00, 0x00, 0x03, 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r',
+ },
+ },
+ {
+ name: "size uidgid permisssions acmodtime extended",
+ flags: AttrSize | AttrUIDGID | AttrPermissions | AttrACModTime | AttrExtended,
+ encoded: []byte{
+ 0x80, 0x00, 0x00, 0x0F,
+ 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0,
+ 0x00, 0x00, 0x03, 0xE8,
+ 0x00, 0x00, 0x00, 100,
+ 0x87, 0x65, 0x43, 0x21,
+ 0x2A, 0x2B, 0x2C, 0x2D,
+ 0x42, 0x43, 0x44, 0x45,
+ 0x00, 0x00, 0x00, 0x01,
+ 0x00, 0x00, 0x00, 0x03, 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r',
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ attr := *attr
+
+ t.Run(tt.name, func(t *testing.T) {
+ attr.Flags = tt.flags
+
+ buf, err := attr.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if !bytes.Equal(buf, tt.encoded) {
+ t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, tt.encoded)
+ }
+
+ attr = Attributes{}
+
+ if err := attr.UnmarshalBinary(buf); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if attr.Flags != tt.flags {
+ t.Errorf("UnmarshalBinary(): Flags was %x, but wanted %x", attr.Flags, tt.flags)
+ }
+
+ if attr.Flags&AttrSize != 0 && attr.Size != size {
+ t.Errorf("UnmarshalBinary(): Size was %x, but wanted %x", attr.Size, size)
+ }
+
+ if attr.Flags&AttrUIDGID != 0 {
+ if attr.UID != uid {
+ t.Errorf("UnmarshalBinary(): UID was %x, but wanted %x", attr.UID, uid)
+ }
+
+ if attr.GID != gid {
+ t.Errorf("UnmarshalBinary(): GID was %x, but wanted %x", attr.GID, gid)
+ }
+ }
+
+ if attr.Flags&AttrPermissions != 0 && attr.Permissions != perms {
+ t.Errorf("UnmarshalBinary(): Permissions was %#v, but wanted %#v", attr.Permissions, perms)
+ }
+
+ if attr.Flags&AttrACModTime != 0 {
+ if attr.ATime != atime {
+ t.Errorf("UnmarshalBinary(): ATime was %x, but wanted %x", attr.ATime, atime)
+ }
+
+ if attr.MTime != mtime {
+ t.Errorf("UnmarshalBinary(): MTime was %x, but wanted %x", attr.MTime, mtime)
+ }
+ }
+
+ if attr.Flags&AttrExtended != 0 {
+ extAttrs := attr.ExtendedAttributes
+
+ if count := len(extAttrs); count != 1 {
+ t.Fatalf("UnmarshalBinary(): len(ExtendedAttributes) was %d, but wanted %d", count, 1)
+ }
+
+ if got := extAttrs[0]; got != extAttr {
+ t.Errorf("UnmarshalBinary(): ExtendedAttributes[0] was %#v, but wanted %#v", got, extAttr)
+ }
+ }
+ })
+ }
+}
+
+func TestNameEntry(t *testing.T) {
+ const (
+ filename = "foo"
+ longname = "bar"
+ perms = 0x87654321
+ )
+
+ e := &NameEntry{
+ Filename: filename,
+ Longname: longname,
+ Attrs: Attributes{
+ Flags: AttrPermissions,
+ Permissions: perms,
+ },
+ }
+
+ buf, err := e.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 0x03, 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r',
+ 0x00, 0x00, 0x00, 0x04,
+ 0x87, 0x65, 0x43, 0x21,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want)
+ }
+
+ *e = NameEntry{}
+
+ if err := e.UnmarshalBinary(buf); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if e.Filename != filename {
+ t.Errorf("UnmarhsalFrom(): Filename was %q, but expected %q", e.Filename, filename)
+ }
+
+ if e.Longname != longname {
+ t.Errorf("UnmarhsalFrom(): Longname was %q, but expected %q", e.Longname, longname)
+ }
+
+ if e.Attrs.Flags != AttrPermissions {
+ t.Errorf("UnmarshalBinary(): Attrs.Flag was %#x, but expected %#x", e.Attrs.Flags, AttrPermissions)
+ }
+
+ if e.Attrs.Permissions != perms {
+ t.Errorf("UnmarshalBinary(): Attrs.Permissions was %#v, but expected %#v", e.Attrs.Permissions, perms)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/buffer.go b/internal/encoding/ssh/filexfer/buffer.go
new file mode 100644
index 0000000..a608603
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/buffer.go
@@ -0,0 +1,293 @@
+package filexfer
+
+import (
+ "encoding/binary"
+ "errors"
+)
+
+// Various encoding errors.
+var (
+ ErrShortPacket = errors.New("packet too short")
+ ErrLongPacket = errors.New("packet too long")
+)
+
+// Buffer wraps up the various encoding details of the SSH format.
+//
+// Data types are encoded as per section 4 from https://tools.ietf.org/html/draft-ietf-secsh-architecture-09#page-8
+type Buffer struct {
+ b []byte
+ off int
+}
+
+// NewBuffer creates and initializes a new buffer using buf as its initial contents.
+// The new buffer takes ownership of buf, and the caller should not use buf after this call.
+//
+// In most cases, new(Buffer) (or just declaring a Buffer variable) is sufficient to initialize a Buffer.
+func NewBuffer(buf []byte) *Buffer {
+ return &Buffer{
+ b: buf,
+ }
+}
+
+// NewMarshalBuffer creates a new Buffer ready to start marshaling a Packet into.
+// It preallocates enough space for uint32(length), uint8(type), uint32(request-id) and size more bytes.
+func NewMarshalBuffer(size int) *Buffer {
+ return NewBuffer(make([]byte, 4+1+4+size))
+}
+
+// Bytes returns a slice of length b.Len() holding the unconsumed bytes in the Buffer.
+// The slice is valid for use only until the next buffer modification
+// (that is, only until the next call to an Append or Consume method).
+func (b *Buffer) Bytes() []byte {
+ return b.b[b.off:]
+}
+
+// Len returns the number of unconsumed bytes in the buffer.
+func (b *Buffer) Len() int { return len(b.b) - b.off }
+
+// Cap returns the capacity of the buffer’s underlying byte slice,
+// that is, the total space allocated for the buffer’s data.
+func (b *Buffer) Cap() int { return cap(b.b) }
+
+// Reset resets the buffer to be empty, but it retains the underlying storage for use by future Appends.
+func (b *Buffer) Reset() {
+ b.b = b.b[:0]
+ b.off = 0
+}
+
+// StartPacket resets and initializes the buffer to be ready to start marshaling a packet into.
+// It truncates the buffer, reserves space for uint32(length), then appends the given packetType and requestID.
+func (b *Buffer) StartPacket(packetType PacketType, requestID uint32) {
+ b.b, b.off = append(b.b[:0], make([]byte, 4)...), 0
+
+ b.AppendUint8(uint8(packetType))
+ b.AppendUint32(requestID)
+}
+
+// Packet finalizes the packet started from StartPacket.
+// It is expected that this will end the ownership of the underlying byte-slice,
+// and so the returned byte-slices may be reused the same as any other byte-slice,
+// the caller should not use this buffer after this call.
+//
+// It writes the packet body length into the first four bytes of the buffer in network byte order (big endian).
+// The packet body length is the length of this buffer less the 4-byte length itself, plus the length of payload.
+//
+// It is assumed that no Consume methods have been called on this buffer,
+// and so it returns the whole underlying slice.
+func (b *Buffer) Packet(payload []byte) (header, payloadPassThru []byte, err error) {
+ b.PutLength(len(b.b) - 4 + len(payload))
+
+ return b.b, payload, nil
+}
+
+// ConsumeUint8 consumes a single byte from the buffer.
+// If the buffer does not have enough data, it will return ErrShortPacket.
+func (b *Buffer) ConsumeUint8() (uint8, error) {
+ if b.Len() < 1 {
+ return 0, ErrShortPacket
+ }
+
+ var v uint8
+ v, b.off = b.b[b.off], b.off+1
+ return v, nil
+}
+
+// AppendUint8 appends a single byte into the buffer.
+func (b *Buffer) AppendUint8(v uint8) {
+ b.b = append(b.b, v)
+}
+
+// ConsumeBool consumes a single byte from the buffer, and returns true if that byte is non-zero.
+// If the buffer does not have enough data, it will return ErrShortPacket.
+func (b *Buffer) ConsumeBool() (bool, error) {
+ v, err := b.ConsumeUint8()
+ if err != nil {
+ return false, err
+ }
+
+ return v != 0, nil
+}
+
+// AppendBool appends a single bool into the buffer.
+// It encodes it as a single byte, with false as 0, and true as 1.
+func (b *Buffer) AppendBool(v bool) {
+ if v {
+ b.AppendUint8(1)
+ } else {
+ b.AppendUint8(0)
+ }
+}
+
+// ConsumeUint16 consumes a single uint16 from the buffer, in network byte order (big-endian).
+// If the buffer does not have enough data, it will return ErrShortPacket.
+func (b *Buffer) ConsumeUint16() (uint16, error) {
+ if b.Len() < 2 {
+ return 0, ErrShortPacket
+ }
+
+ v := binary.BigEndian.Uint16(b.b[b.off:])
+ b.off += 2
+ return v, nil
+}
+
+// AppendUint16 appends single uint16 into the buffer, in network byte order (big-endian).
+func (b *Buffer) AppendUint16(v uint16) {
+ b.b = append(b.b,
+ byte(v>>8),
+ byte(v>>0),
+ )
+}
+
+// unmarshalUint32 is used internally to read the packet length.
+// It is unsafe, and so not exported.
+// Even within this package, its use should be avoided.
+func unmarshalUint32(b []byte) uint32 {
+ return binary.BigEndian.Uint32(b[:4])
+}
+
+// ConsumeUint32 consumes a single uint32 from the buffer, in network byte order (big-endian).
+// If the buffer does not have enough data, it will return ErrShortPacket.
+func (b *Buffer) ConsumeUint32() (uint32, error) {
+ if b.Len() < 4 {
+ return 0, ErrShortPacket
+ }
+
+ v := binary.BigEndian.Uint32(b.b[b.off:])
+ b.off += 4
+ return v, nil
+}
+
+// AppendUint32 appends a single uint32 into the buffer, in network byte order (big-endian).
+func (b *Buffer) AppendUint32(v uint32) {
+ b.b = append(b.b,
+ byte(v>>24),
+ byte(v>>16),
+ byte(v>>8),
+ byte(v>>0),
+ )
+}
+
+// ConsumeUint64 consumes a single uint64 from the buffer, in network byte order (big-endian).
+// If the buffer does not have enough data, it will return ErrShortPacket.
+func (b *Buffer) ConsumeUint64() (uint64, error) {
+ if b.Len() < 8 {
+ return 0, ErrShortPacket
+ }
+
+ v := binary.BigEndian.Uint64(b.b[b.off:])
+ b.off += 8
+ return v, nil
+}
+
+// AppendUint64 appends a single uint64 into the buffer, in network byte order (big-endian).
+func (b *Buffer) AppendUint64(v uint64) {
+ b.b = append(b.b,
+ byte(v>>56),
+ byte(v>>48),
+ byte(v>>40),
+ byte(v>>32),
+ byte(v>>24),
+ byte(v>>16),
+ byte(v>>8),
+ byte(v>>0),
+ )
+}
+
+// ConsumeInt64 consumes a single int64 from the buffer, in network byte order (big-endian) with two’s complement.
+// If the buffer does not have enough data, it will return ErrShortPacket.
+func (b *Buffer) ConsumeInt64() (int64, error) {
+ u, err := b.ConsumeUint64()
+ if err != nil {
+ return 0, err
+ }
+
+ return int64(u), err
+}
+
+// AppendInt64 appends a single int64 into the buffer, in network byte order (big-endian) with two’s complement.
+func (b *Buffer) AppendInt64(v int64) {
+ b.AppendUint64(uint64(v))
+}
+
+// ConsumeByteSlice consumes a single string of raw binary data from the buffer.
+// A string is a uint32 length, followed by that number of raw bytes.
+// If the buffer does not have enough data, or defines a length larger than available, it will return ErrShortPacket.
+//
+// The returned slice aliases the buffer contents, and is valid only as long as the buffer is not reused
+// (that is, only until the next call to Reset, PutLength, StartPacket, or UnmarshalBinary).
+//
+// In no case will any Consume calls return overlapping slice aliases,
+// and Append calls are guaranteed to not disturb this slice alias.
+func (b *Buffer) ConsumeByteSlice() ([]byte, error) {
+ length, err := b.ConsumeUint32()
+ if err != nil {
+ return nil, err
+ }
+
+ if b.Len() < int(length) {
+ return nil, ErrShortPacket
+ }
+
+ v := b.b[b.off:]
+ if len(v) > int(length) {
+ v = v[:length:length]
+ }
+ b.off += int(length)
+ return v, nil
+}
+
+// AppendByteSlice appends a single string of raw binary data into the buffer.
+// A string is a uint32 length, followed by that number of raw bytes.
+func (b *Buffer) AppendByteSlice(v []byte) {
+ b.AppendUint32(uint32(len(v)))
+ b.b = append(b.b, v...)
+}
+
+// ConsumeString consumes a single string of binary data from the buffer.
+// A string is a uint32 length, followed by that number of raw bytes.
+// If the buffer does not have enough data, or defines a length larger than available, it will return ErrShortPacket.
+//
+// NOTE: Go implicitly assumes that strings contain UTF-8 encoded data.
+// All caveats on using arbitrary binary data in Go strings applies.
+func (b *Buffer) ConsumeString() (string, error) {
+ v, err := b.ConsumeByteSlice()
+ if err != nil {
+ return "", err
+ }
+
+ return string(v), nil
+}
+
+// AppendString appends a single string of binary data into the buffer.
+// A string is a uint32 length, followed by that number of raw bytes.
+func (b *Buffer) AppendString(v string) {
+ b.AppendByteSlice([]byte(v))
+}
+
+// PutLength writes the given size into the first four bytes of the buffer in network byte order (big endian).
+func (b *Buffer) PutLength(size int) {
+ if len(b.b) < 4 {
+ b.b = append(b.b, make([]byte, 4-len(b.b))...)
+ }
+
+ binary.BigEndian.PutUint32(b.b, uint32(size))
+}
+
+// MarshalBinary returns a clone of the full internal buffer.
+func (b *Buffer) MarshalBinary() ([]byte, error) {
+ clone := make([]byte, len(b.b))
+ n := copy(clone, b.b)
+ return clone[:n], nil
+}
+
+// UnmarshalBinary sets the internal buffer of b to be a clone of data, and zeros the internal offset.
+func (b *Buffer) UnmarshalBinary(data []byte) error {
+ if grow := len(data) - len(b.b); grow > 0 {
+ b.b = append(b.b, make([]byte, grow)...)
+ }
+
+ n := copy(b.b, data)
+ b.b = b.b[:n]
+ b.off = 0
+ return nil
+}
diff --git a/internal/encoding/ssh/filexfer/extended_packets.go b/internal/encoding/ssh/filexfer/extended_packets.go
new file mode 100644
index 0000000..6b7b2ce
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/extended_packets.go
@@ -0,0 +1,142 @@
+package filexfer
+
+import (
+ "encoding"
+ "sync"
+)
+
+// ExtendedData aliases the untyped interface composition of encoding.BinaryMarshaler and encoding.BinaryUnmarshaler.
+type ExtendedData = interface {
+ encoding.BinaryMarshaler
+ encoding.BinaryUnmarshaler
+}
+
+// ExtendedDataConstructor defines a function that returns a new(ArbitraryExtendedPacket).
+type ExtendedDataConstructor func() ExtendedData
+
+var extendedPacketTypes = struct {
+ mu sync.RWMutex
+ constructors map[string]ExtendedDataConstructor
+}{
+ constructors: make(map[string]ExtendedDataConstructor),
+}
+
+// RegisterExtendedPacketType defines a specific ExtendedDataConstructor for the given extension string.
+func RegisterExtendedPacketType(extension string, constructor ExtendedDataConstructor) {
+ extendedPacketTypes.mu.Lock()
+ defer extendedPacketTypes.mu.Unlock()
+
+ if _, exist := extendedPacketTypes.constructors[extension]; exist {
+ panic("encoding/ssh/filexfer: multiple registration of extended packet type " + extension)
+ }
+
+ extendedPacketTypes.constructors[extension] = constructor
+}
+
+func newExtendedPacket(extension string) ExtendedData {
+ extendedPacketTypes.mu.RLock()
+ defer extendedPacketTypes.mu.RUnlock()
+
+ if f := extendedPacketTypes.constructors[extension]; f != nil {
+ return f()
+ }
+
+ return new(Buffer)
+}
+
+// ExtendedPacket defines the SSH_FXP_CLOSE packet.
+type ExtendedPacket struct {
+ ExtendedRequest string
+
+ Data ExtendedData
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ExtendedPacket) Type() PacketType {
+ return PacketTypeExtended
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+//
+// The Data is marshaled into binary, and returned as the payload.
+func (p *ExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.ExtendedRequest) // string(extended-request)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeExtended, reqid)
+ buf.AppendString(p.ExtendedRequest)
+
+ if p.Data != nil {
+ payload, err = p.Data.MarshalBinary()
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+//
+// If p.Data is nil, and the extension has been registered, a new type will be made from the registration.
+// If the extension has not been registered, then a new Buffer will be allocated.
+// Then the request-specific-data will be unmarshaled from the rest of the buffer.
+func (p *ExtendedPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.ExtendedRequest, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if p.Data == nil {
+ p.Data = newExtendedPacket(p.ExtendedRequest)
+ }
+
+ return p.Data.UnmarshalBinary(buf.Bytes())
+}
+
+// ExtendedReplyPacket defines the SSH_FXP_CLOSE packet.
+type ExtendedReplyPacket struct {
+ Data ExtendedData
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ExtendedReplyPacket) Type() PacketType {
+ return PacketTypeExtendedReply
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+//
+// The Data is marshaled into binary, and returned as the payload.
+func (p *ExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ buf = NewMarshalBuffer(0)
+ }
+
+ buf.StartPacket(PacketTypeExtendedReply, reqid)
+
+ if p.Data != nil {
+ payload, err = p.Data.MarshalBinary()
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+//
+// If p.Data is nil, and there is request-specific-data,
+// then the request-specific-data will be wrapped in a Buffer and assigned to p.Data.
+func (p *ExtendedReplyPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Data == nil {
+ p.Data = new(Buffer)
+ }
+
+ return p.Data.UnmarshalBinary(buf.Bytes())
+}
diff --git a/internal/encoding/ssh/filexfer/extended_packets_test.go b/internal/encoding/ssh/filexfer/extended_packets_test.go
new file mode 100644
index 0000000..0860773
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/extended_packets_test.go
@@ -0,0 +1,240 @@
+package filexfer
+
+import (
+ "bytes"
+ "testing"
+)
+
+type testExtendedData struct {
+ value uint8
+}
+
+func (d *testExtendedData) MarshalBinary() ([]byte, error) {
+ buf := NewBuffer(make([]byte, 0, 4))
+
+ buf.AppendUint8(d.value ^ 0x2a)
+
+ return buf.Bytes(), nil
+}
+
+func (d *testExtendedData) UnmarshalBinary(data []byte) error {
+ buf := NewBuffer(data)
+
+ v, err := buf.ConsumeUint8()
+ if err != nil {
+ return err
+ }
+
+ d.value = v ^ 0x2a
+
+ return nil
+}
+
+var _ Packet = &ExtendedPacket{}
+
+func TestExtendedPacketNoData(t *testing.T) {
+ const (
+ id = 42
+ extendedRequest = "foo@example"
+ )
+
+ p := &ExtendedPacket{
+ ExtendedRequest: extendedRequest,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 20,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 11, 'f', 'o', 'o', '@', 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = ExtendedPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extendedRequest {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest)
+ }
+}
+
+func TestExtendedPacketTestData(t *testing.T) {
+ const (
+ id = 42
+ extendedRequest = "foo@example"
+ textValue = 13
+ )
+
+ const value = 13
+
+ p := &ExtendedPacket{
+ ExtendedRequest: extendedRequest,
+ Data: &testExtendedData{
+ value: textValue,
+ },
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 21,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 11, 'f', 'o', 'o', '@', 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ 0x27,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = ExtendedPacket{
+ Data: new(testExtendedData),
+ }
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extendedRequest {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest)
+ }
+
+ if buf, ok := p.Data.(*testExtendedData); !ok {
+ t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf)
+
+ } else if buf.value != value {
+ t.Errorf("UnmarshalPacketBody(): Data.value was %#x, but expected %#x", buf.value, value)
+ }
+
+ *p = ExtendedPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extendedRequest {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest)
+ }
+
+ wantBuffer := []byte{0x27}
+
+ if buf, ok := p.Data.(*Buffer); !ok {
+ t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf)
+
+ } else if !bytes.Equal(buf.b, wantBuffer) {
+ t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", buf.b, wantBuffer)
+ }
+}
+
+var _ Packet = &ExtendedReplyPacket{}
+
+func TestExtendedReplyNoData(t *testing.T) {
+ const (
+ id = 42
+ )
+
+ p := &ExtendedReplyPacket{}
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 5,
+ 201,
+ 0x00, 0x00, 0x00, 42,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = ExtendedReplyPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+}
+
+func TestExtendedReplyPacketTestData(t *testing.T) {
+ const (
+ id = 42
+ textValue = 13
+ )
+
+ const value = 13
+
+ p := &ExtendedReplyPacket{
+ Data: &testExtendedData{
+ value: textValue,
+ },
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 6,
+ 201,
+ 0x00, 0x00, 0x00, 42,
+ 0x27,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = ExtendedReplyPacket{
+ Data: new(testExtendedData),
+ }
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if buf, ok := p.Data.(*testExtendedData); !ok {
+ t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf)
+
+ } else if buf.value != value {
+ t.Errorf("UnmarshalPacketBody(): Data.value was %#x, but expected %#x", buf.value, value)
+ }
+
+ *p = ExtendedReplyPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ wantBuffer := []byte{0x27}
+
+ if buf, ok := p.Data.(*Buffer); !ok {
+ t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf)
+
+ } else if !bytes.Equal(buf.b, wantBuffer) {
+ t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", buf.b, wantBuffer)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/extensions.go b/internal/encoding/ssh/filexfer/extensions.go
new file mode 100644
index 0000000..11c0b99
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/extensions.go
@@ -0,0 +1,46 @@
+package filexfer
+
+// ExtensionPair defines the extension-pair type defined in draft-ietf-secsh-filexfer-13.
+// This type is backwards-compatible with how draft-ietf-secsh-filexfer-02 defines extensions.
+//
+// Defined in: https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-4.2
+type ExtensionPair struct {
+ Name string
+ Data string
+}
+
+// Len returns the number of bytes e would marshal into.
+func (e *ExtensionPair) Len() int {
+ return 4 + len(e.Name) + 4 + len(e.Data)
+}
+
+// MarshalInto marshals e onto the end of the given Buffer.
+func (e *ExtensionPair) MarshalInto(buf *Buffer) {
+ buf.AppendString(e.Name)
+ buf.AppendString(e.Data)
+}
+
+// MarshalBinary returns e as the binary encoding of e.
+func (e *ExtensionPair) MarshalBinary() ([]byte, error) {
+ buf := NewBuffer(make([]byte, 0, e.Len()))
+ e.MarshalInto(buf)
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom unmarshals an ExtensionPair from the given Buffer into e.
+func (e *ExtensionPair) UnmarshalFrom(buf *Buffer) (err error) {
+ if e.Name, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if e.Data, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the binary encoding of ExtensionPair into e.
+func (e *ExtensionPair) UnmarshalBinary(data []byte) error {
+ return e.UnmarshalFrom(NewBuffer(data))
+}
diff --git a/internal/encoding/ssh/filexfer/extensions_test.go b/internal/encoding/ssh/filexfer/extensions_test.go
new file mode 100644
index 0000000..453265b
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/extensions_test.go
@@ -0,0 +1,49 @@
+package filexfer
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestExtensionPair(t *testing.T) {
+ const (
+ name = "foo"
+ data = "1"
+ )
+
+ pair := &ExtensionPair{
+ Name: name,
+ Data: data,
+ }
+
+ buf, err := pair.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 3,
+ 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 1,
+ '1',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Errorf("ExtensionPair.MarshalBinary() = %X, but wanted %X", buf, want)
+ }
+
+ *pair = ExtensionPair{}
+
+ if err := pair.UnmarshalBinary(buf); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if pair.Name != name {
+ t.Errorf("ExtensionPair.UnmarshalBinary(): Name was %q, but expected %q", pair.Name, name)
+ }
+
+ if pair.Data != data {
+ t.Errorf("RawPacket.UnmarshalBinary(): Data was %q, but expected %q", pair.Data, data)
+ }
+
+}
diff --git a/internal/encoding/ssh/filexfer/filexfer.go b/internal/encoding/ssh/filexfer/filexfer.go
new file mode 100644
index 0000000..1e5abf7
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/filexfer.go
@@ -0,0 +1,54 @@
+// Package filexfer implements the wire encoding for secsh-filexfer as described in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
+package filexfer
+
+// PacketMarshaller narrowly defines packets that will only be transmitted.
+//
+// ExtendedPacket types will often only implement this interface,
+// since decoding the whole packet body of an ExtendedPacket can only be done dependent on the ExtendedRequest field.
+type PacketMarshaller interface {
+ // MarshalPacket is the primary intended way to encode a packet.
+ // The request-id for the packet is set from reqid.
+ //
+ // An optional buffer may be given in b.
+ // If the buffer has a minimum capacity, it shall be truncated and used to marshal the header into.
+ // The minimum capacity for the packet must be a constant expression, and should be at least 9.
+ //
+ // It shall return the main body of the encoded packet in header,
+ // and may optionally return an additional payload to be written immediately after the header.
+ //
+ // It shall encode in the first 4-bytes of the header the proper length of the rest of the header+payload.
+ MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error)
+}
+
+// Packet defines the behavior of a full generic SFTP packet.
+//
+// InitPacket, and VersionPacket are not generic SFTP packets, and instead implement (Un)MarshalBinary.
+//
+// ExtendedPacket types should not iplement this interface,
+// since decoding the whole packet body of an ExtendedPacket can only be done dependent on the ExtendedRequest field.
+type Packet interface {
+ PacketMarshaller
+
+ // Type returns the SSH_FXP_xy value associated with the specific packet.
+ Type() PacketType
+
+ // UnmarshalPacketBody decodes a packet body from the given Buffer.
+ // It is assumed that the common header values of the length, type and request-id have already been consumed.
+ //
+ // Implementations should not alias the given Buffer,
+ // instead they can consider prepopulating an internal buffer as a hint,
+ // and copying into that buffer if it has sufficient length.
+ UnmarshalPacketBody(buf *Buffer) error
+}
+
+// ComposePacket converts returns from MarshalPacket into an equivalent call to MarshalBinary.
+func ComposePacket(header, payload []byte, err error) ([]byte, error) {
+ return append(header, payload...), err
+}
+
+// Default length values,
+// Defined in draft-ietf-secsh-filexfer-02 section 3.
+const (
+ DefaultMaxPacketLength = 34000
+ DefaultMaxDataLength = 32768
+)
diff --git a/internal/encoding/ssh/filexfer/fx.go b/internal/encoding/ssh/filexfer/fx.go
new file mode 100644
index 0000000..48f8698
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/fx.go
@@ -0,0 +1,147 @@
+package filexfer
+
+import (
+ "fmt"
+)
+
+// Status defines the SFTP error codes used in SSH_FXP_STATUS response packets.
+type Status uint32
+
+// Defines the various SSH_FX_* values.
+const (
+ // see draft-ietf-secsh-filexfer-02
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-7
+ StatusOK = Status(iota)
+ StatusEOF
+ StatusNoSuchFile
+ StatusPermissionDenied
+ StatusFailure
+ StatusBadMessage
+ StatusNoConnection
+ StatusConnectionLost
+ StatusOPUnsupported
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-03#section-7
+ StatusV4InvalidHandle
+ StatusV4NoSuchPath
+ StatusV4FileAlreadyExists
+ StatusV4WriteProtect
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-04#section-7
+ StatusV4NoMedia
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-05#section-7
+ StatusV5NoSpaceOnFilesystem
+ StatusV5QuotaExceeded
+ StatusV5UnknownPrincipal
+ StatusV5LockConflict
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-06#section-8
+ StatusV6DirNotEmpty
+ StatusV6NotADirectory
+ StatusV6InvalidFilename
+ StatusV6LinkLoop
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-07#section-8
+ StatusV6CannotDelete
+ StatusV6InvalidParameter
+ StatusV6FileIsADirectory
+ StatusV6ByteRangeLockConflict
+ StatusV6ByteRangeLockRefused
+ StatusV6DeletePending
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-08#section-8.1
+ StatusV6FileCorrupt
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-10#section-9.1
+ StatusV6OwnerInvalid
+ StatusV6GroupInvalid
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1
+ StatusV6NoMatchingByteRangeLock
+)
+
+func (s Status) Error() string {
+ return s.String()
+}
+
+// Is returns true if the target is the same Status code,
+// or target is a StatusPacket with the same Status code.
+func (s Status) Is(target error) bool {
+ if target, ok := target.(*StatusPacket); ok {
+ return target.StatusCode == s
+ }
+
+ return s == target
+}
+
+func (s Status) String() string {
+ switch s {
+ case StatusOK:
+ return "SSH_FX_OK"
+ case StatusEOF:
+ return "SSH_FX_EOF"
+ case StatusNoSuchFile:
+ return "SSH_FX_NO_SUCH_FILE"
+ case StatusPermissionDenied:
+ return "SSH_FX_PERMISSION_DENIED"
+ case StatusFailure:
+ return "SSH_FX_FAILURE"
+ case StatusBadMessage:
+ return "SSH_FX_BAD_MESSAGE"
+ case StatusNoConnection:
+ return "SSH_FX_NO_CONNECTION"
+ case StatusConnectionLost:
+ return "SSH_FX_CONNECTION_LOST"
+ case StatusOPUnsupported:
+ return "SSH_FX_OP_UNSUPPORTED"
+ case StatusV4InvalidHandle:
+ return "SSH_FX_INVALID_HANDLE"
+ case StatusV4NoSuchPath:
+ return "SSH_FX_NO_SUCH_PATH"
+ case StatusV4FileAlreadyExists:
+ return "SSH_FX_FILE_ALREADY_EXISTS"
+ case StatusV4WriteProtect:
+ return "SSH_FX_WRITE_PROTECT"
+ case StatusV4NoMedia:
+ return "SSH_FX_NO_MEDIA"
+ case StatusV5NoSpaceOnFilesystem:
+ return "SSH_FX_NO_SPACE_ON_FILESYSTEM"
+ case StatusV5QuotaExceeded:
+ return "SSH_FX_QUOTA_EXCEEDED"
+ case StatusV5UnknownPrincipal:
+ return "SSH_FX_UNKNOWN_PRINCIPAL"
+ case StatusV5LockConflict:
+ return "SSH_FX_LOCK_CONFLICT"
+ case StatusV6DirNotEmpty:
+ return "SSH_FX_DIR_NOT_EMPTY"
+ case StatusV6NotADirectory:
+ return "SSH_FX_NOT_A_DIRECTORY"
+ case StatusV6InvalidFilename:
+ return "SSH_FX_INVALID_FILENAME"
+ case StatusV6LinkLoop:
+ return "SSH_FX_LINK_LOOP"
+ case StatusV6CannotDelete:
+ return "SSH_FX_CANNOT_DELETE"
+ case StatusV6InvalidParameter:
+ return "SSH_FX_INVALID_PARAMETER"
+ case StatusV6FileIsADirectory:
+ return "SSH_FX_FILE_IS_A_DIRECTORY"
+ case StatusV6ByteRangeLockConflict:
+ return "SSH_FX_BYTE_RANGE_LOCK_CONFLICT"
+ case StatusV6ByteRangeLockRefused:
+ return "SSH_FX_BYTE_RANGE_LOCK_REFUSED"
+ case StatusV6DeletePending:
+ return "SSH_FX_DELETE_PENDING"
+ case StatusV6FileCorrupt:
+ return "SSH_FX_FILE_CORRUPT"
+ case StatusV6OwnerInvalid:
+ return "SSH_FX_OWNER_INVALID"
+ case StatusV6GroupInvalid:
+ return "SSH_FX_GROUP_INVALID"
+ case StatusV6NoMatchingByteRangeLock:
+ return "SSH_FX_NO_MATCHING_BYTE_RANGE_LOCK"
+ default:
+ return fmt.Sprintf("SSH_FX_UNKNOWN(%d)", s)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/fx_test.go b/internal/encoding/ssh/filexfer/fx_test.go
new file mode 100644
index 0000000..3e8db1d
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/fx_test.go
@@ -0,0 +1,102 @@
+package filexfer
+
+import (
+ "bufio"
+ "errors"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+// This string data is copied verbatim from https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13
+var fxStandardsText = `
+SSH_FX_OK 0
+SSH_FX_EOF 1
+SSH_FX_NO_SUCH_FILE 2
+SSH_FX_PERMISSION_DENIED 3
+SSH_FX_FAILURE 4
+SSH_FX_BAD_MESSAGE 5
+SSH_FX_NO_CONNECTION 6
+SSH_FX_CONNECTION_LOST 7
+SSH_FX_OP_UNSUPPORTED 8
+SSH_FX_INVALID_HANDLE 9
+SSH_FX_NO_SUCH_PATH 10
+SSH_FX_FILE_ALREADY_EXISTS 11
+SSH_FX_WRITE_PROTECT 12
+SSH_FX_NO_MEDIA 13
+SSH_FX_NO_SPACE_ON_FILESYSTEM 14
+SSH_FX_QUOTA_EXCEEDED 15
+SSH_FX_UNKNOWN_PRINCIPAL 16
+SSH_FX_LOCK_CONFLICT 17
+SSH_FX_DIR_NOT_EMPTY 18
+SSH_FX_NOT_A_DIRECTORY 19
+SSH_FX_INVALID_FILENAME 20
+SSH_FX_LINK_LOOP 21
+SSH_FX_CANNOT_DELETE 22
+SSH_FX_INVALID_PARAMETER 23
+SSH_FX_FILE_IS_A_DIRECTORY 24
+SSH_FX_BYTE_RANGE_LOCK_CONFLICT 25
+SSH_FX_BYTE_RANGE_LOCK_REFUSED 26
+SSH_FX_DELETE_PENDING 27
+SSH_FX_FILE_CORRUPT 28
+SSH_FX_OWNER_INVALID 29
+SSH_FX_GROUP_INVALID 30
+SSH_FX_NO_MATCHING_BYTE_RANGE_LOCK 31
+`
+
+func TestFxNames(t *testing.T) {
+ whitespace := regexp.MustCompile(`[[:space:]]+`)
+
+ scan := bufio.NewScanner(strings.NewReader(fxStandardsText))
+
+ for scan.Scan() {
+ line := scan.Text()
+ if i := strings.Index(line, "//"); i >= 0 {
+ line = line[:i]
+ }
+
+ line = strings.TrimSpace(line)
+ if line == "" {
+ continue
+ }
+
+ fields := whitespace.Split(line, 2)
+ if len(fields) < 2 {
+ t.Fatalf("unexpected standards text line: %q", line)
+ }
+
+ name, value := fields[0], fields[1]
+ n, err := strconv.Atoi(value)
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ fx := Status(n)
+
+ if got := fx.String(); got != name {
+ t.Errorf("fx name mismatch for %d: got %q, but want %q", n, got, name)
+ }
+ }
+
+ if err := scan.Err(); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+}
+
+func TestStatusIs(t *testing.T) {
+ status := StatusFailure
+
+ if !errors.Is(status, StatusFailure) {
+ t.Error("errors.Is(StatusFailure, StatusFailure) != true")
+ }
+ if !errors.Is(status, &StatusPacket{StatusCode: StatusFailure}) {
+ t.Error("errors.Is(StatusFailure, StatusPacket{StatusFailure}) != true")
+ }
+ if errors.Is(status, StatusOK) {
+ t.Error("errors.Is(StatusFailure, StatusFailure) == true")
+ }
+ if errors.Is(status, &StatusPacket{StatusCode: StatusOK}) {
+ t.Error("errors.Is(StatusFailure, StatusPacket{StatusFailure}) == true")
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/fxp.go b/internal/encoding/ssh/filexfer/fxp.go
new file mode 100644
index 0000000..15caf6d
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/fxp.go
@@ -0,0 +1,124 @@
+package filexfer
+
+import (
+ "fmt"
+)
+
+// PacketType defines the various SFTP packet types.
+type PacketType uint8
+
+// Request packet types.
+const (
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-3
+ PacketTypeInit = PacketType(iota + 1)
+ PacketTypeVersion
+ PacketTypeOpen
+ PacketTypeClose
+ PacketTypeRead
+ PacketTypeWrite
+ PacketTypeLStat
+ PacketTypeFStat
+ PacketTypeSetstat
+ PacketTypeFSetstat
+ PacketTypeOpenDir
+ PacketTypeReadDir
+ PacketTypeRemove
+ PacketTypeMkdir
+ PacketTypeRmdir
+ PacketTypeRealPath
+ PacketTypeStat
+ PacketTypeRename
+ PacketTypeReadLink
+ PacketTypeSymlink
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-07#section-3.3
+ PacketTypeV6Link
+
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-08#section-3.3
+ PacketTypeV6Block
+ PacketTypeV6Unblock
+)
+
+// Response packet types.
+const (
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-3
+ PacketTypeStatus = PacketType(iota + 101)
+ PacketTypeHandle
+ PacketTypeData
+ PacketTypeName
+ PacketTypeAttrs
+)
+
+// Extended packet types.
+const (
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-3
+ PacketTypeExtended = PacketType(iota + 200)
+ PacketTypeExtendedReply
+)
+
+func (f PacketType) String() string {
+ switch f {
+ case PacketTypeInit:
+ return "SSH_FXP_INIT"
+ case PacketTypeVersion:
+ return "SSH_FXP_VERSION"
+ case PacketTypeOpen:
+ return "SSH_FXP_OPEN"
+ case PacketTypeClose:
+ return "SSH_FXP_CLOSE"
+ case PacketTypeRead:
+ return "SSH_FXP_READ"
+ case PacketTypeWrite:
+ return "SSH_FXP_WRITE"
+ case PacketTypeLStat:
+ return "SSH_FXP_LSTAT"
+ case PacketTypeFStat:
+ return "SSH_FXP_FSTAT"
+ case PacketTypeSetstat:
+ return "SSH_FXP_SETSTAT"
+ case PacketTypeFSetstat:
+ return "SSH_FXP_FSETSTAT"
+ case PacketTypeOpenDir:
+ return "SSH_FXP_OPENDIR"
+ case PacketTypeReadDir:
+ return "SSH_FXP_READDIR"
+ case PacketTypeRemove:
+ return "SSH_FXP_REMOVE"
+ case PacketTypeMkdir:
+ return "SSH_FXP_MKDIR"
+ case PacketTypeRmdir:
+ return "SSH_FXP_RMDIR"
+ case PacketTypeRealPath:
+ return "SSH_FXP_REALPATH"
+ case PacketTypeStat:
+ return "SSH_FXP_STAT"
+ case PacketTypeRename:
+ return "SSH_FXP_RENAME"
+ case PacketTypeReadLink:
+ return "SSH_FXP_READLINK"
+ case PacketTypeSymlink:
+ return "SSH_FXP_SYMLINK"
+ case PacketTypeV6Link:
+ return "SSH_FXP_LINK"
+ case PacketTypeV6Block:
+ return "SSH_FXP_BLOCK"
+ case PacketTypeV6Unblock:
+ return "SSH_FXP_UNBLOCK"
+ case PacketTypeStatus:
+ return "SSH_FXP_STATUS"
+ case PacketTypeHandle:
+ return "SSH_FXP_HANDLE"
+ case PacketTypeData:
+ return "SSH_FXP_DATA"
+ case PacketTypeName:
+ return "SSH_FXP_NAME"
+ case PacketTypeAttrs:
+ return "SSH_FXP_ATTRS"
+ case PacketTypeExtended:
+ return "SSH_FXP_EXTENDED"
+ case PacketTypeExtendedReply:
+ return "SSH_FXP_EXTENDED_REPLY"
+ default:
+ return fmt.Sprintf("SSH_FXP_UNKNOWN(%d)", f)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/fxp_test.go b/internal/encoding/ssh/filexfer/fxp_test.go
new file mode 100644
index 0000000..14e8ff7
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/fxp_test.go
@@ -0,0 +1,84 @@
+package filexfer
+
+import (
+ "bufio"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+// This string data is copied verbatim from https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13
+var fxpStandardsText = `
+SSH_FXP_INIT 1
+SSH_FXP_VERSION 2
+SSH_FXP_OPEN 3
+SSH_FXP_CLOSE 4
+SSH_FXP_READ 5
+SSH_FXP_WRITE 6
+SSH_FXP_LSTAT 7
+SSH_FXP_FSTAT 8
+SSH_FXP_SETSTAT 9
+SSH_FXP_FSETSTAT 10
+SSH_FXP_OPENDIR 11
+SSH_FXP_READDIR 12
+SSH_FXP_REMOVE 13
+SSH_FXP_MKDIR 14
+SSH_FXP_RMDIR 15
+SSH_FXP_REALPATH 16
+SSH_FXP_STAT 17
+SSH_FXP_RENAME 18
+SSH_FXP_READLINK 19
+SSH_FXP_SYMLINK 20 // Deprecated in filexfer-13 added from filexfer-02
+SSH_FXP_LINK 21
+SSH_FXP_BLOCK 22
+SSH_FXP_UNBLOCK 23
+
+SSH_FXP_STATUS 101
+SSH_FXP_HANDLE 102
+SSH_FXP_DATA 103
+SSH_FXP_NAME 104
+SSH_FXP_ATTRS 105
+
+SSH_FXP_EXTENDED 200
+SSH_FXP_EXTENDED_REPLY 201
+`
+
+func TestFxpNames(t *testing.T) {
+ whitespace := regexp.MustCompile(`[[:space:]]+`)
+
+ scan := bufio.NewScanner(strings.NewReader(fxpStandardsText))
+
+ for scan.Scan() {
+ line := scan.Text()
+ if i := strings.Index(line, "//"); i >= 0 {
+ line = line[:i]
+ }
+
+ line = strings.TrimSpace(line)
+ if line == "" {
+ continue
+ }
+
+ fields := whitespace.Split(line, 2)
+ if len(fields) < 2 {
+ t.Fatalf("unexpected standards text line: %q", line)
+ }
+
+ name, value := fields[0], fields[1]
+ n, err := strconv.Atoi(value)
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ fxp := PacketType(n)
+
+ if got := fxp.String(); got != name {
+ t.Errorf("fxp name mismatch for %d: got %q, but want %q", n, got, name)
+ }
+ }
+
+ if err := scan.Err(); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/handle_packets.go b/internal/encoding/ssh/filexfer/handle_packets.go
new file mode 100644
index 0000000..a142771
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/handle_packets.go
@@ -0,0 +1,249 @@
+package filexfer
+
+// ClosePacket defines the SSH_FXP_CLOSE packet.
+type ClosePacket struct {
+ Handle string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ClosePacket) Type() PacketType {
+ return PacketTypeClose
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *ClosePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Handle) // string(handle)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeClose, reqid)
+ buf.AppendString(p.Handle)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *ClosePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// ReadPacket defines the SSH_FXP_READ packet.
+type ReadPacket struct {
+ Handle string
+ Offset uint64
+ Len uint32
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ReadPacket) Type() PacketType {
+ return PacketTypeRead
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *ReadPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ // string(handle) + uint64(offset) + uint32(len)
+ size := 4 + len(p.Handle) + 8 + 4
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeRead, reqid)
+ buf.AppendString(p.Handle)
+ buf.AppendUint64(p.Offset)
+ buf.AppendUint32(p.Len)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *ReadPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if p.Offset, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+
+ if p.Len, err = buf.ConsumeUint32(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WritePacket defines the SSH_FXP_WRITE packet.
+type WritePacket struct {
+ Handle string
+ Offset uint64
+ Data []byte
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *WritePacket) Type() PacketType {
+ return PacketTypeWrite
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *WritePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ // string(handle) + uint64(offset) + uint32(len(data)); data content in payload
+ size := 4 + len(p.Handle) + 8 + 4
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeWrite, reqid)
+ buf.AppendString(p.Handle)
+ buf.AppendUint64(p.Offset)
+ buf.AppendUint32(uint32(len(p.Data)))
+
+ return buf.Packet(p.Data)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+//
+// If p.Data is already populated, and of sufficient length to hold the data,
+// then this will copy the data into that byte slice.
+//
+// If p.Data has a length insufficient to hold the data,
+// then this will make a new slice of sufficient length, and copy the data into that.
+//
+// This means this _does not_ alias any of the data buffer that is passed in.
+func (p *WritePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if p.Offset, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+
+ data, err := buf.ConsumeByteSlice()
+ if err != nil {
+ return err
+ }
+
+ if len(p.Data) < len(data) {
+ p.Data = make([]byte, len(data))
+ }
+
+ n := copy(p.Data, data)
+ p.Data = p.Data[:n]
+ return nil
+}
+
+// FStatPacket defines the SSH_FXP_FSTAT packet.
+type FStatPacket struct {
+ Handle string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *FStatPacket) Type() PacketType {
+ return PacketTypeFStat
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *FStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Handle) // string(handle)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeFStat, reqid)
+ buf.AppendString(p.Handle)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *FStatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// FSetstatPacket defines the SSH_FXP_FSETSTAT packet.
+type FSetstatPacket struct {
+ Handle string
+ Attrs Attributes
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *FSetstatPacket) Type() PacketType {
+ return PacketTypeFSetstat
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *FSetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Handle) + p.Attrs.Len() // string(handle) + ATTRS(attrs)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeFSetstat, reqid)
+ buf.AppendString(p.Handle)
+
+ p.Attrs.MarshalInto(buf)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *FSetstatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return p.Attrs.UnmarshalFrom(buf)
+}
+
+// ReadDirPacket defines the SSH_FXP_READDIR packet.
+type ReadDirPacket struct {
+ Handle string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ReadDirPacket) Type() PacketType {
+ return PacketTypeReadDir
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *ReadDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Handle) // string(handle)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeReadDir, reqid)
+ buf.AppendString(p.Handle)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *ReadDirPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/encoding/ssh/filexfer/handle_packets_test.go b/internal/encoding/ssh/filexfer/handle_packets_test.go
new file mode 100644
index 0000000..8cc394e
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/handle_packets_test.go
@@ -0,0 +1,282 @@
+package filexfer
+
+import (
+ "bytes"
+ "testing"
+)
+
+var _ Packet = &ClosePacket{}
+
+func TestClosePacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ )
+
+ p := &ClosePacket{
+ Handle: "somehandle",
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 19,
+ 4,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = ClosePacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Handle != handle {
+ t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle)
+ }
+}
+
+var _ Packet = &ReadPacket{}
+
+func TestReadPacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ offset = 0x123456789ABCDEF0
+ length = 0xFEDCBA98
+ )
+
+ p := &ReadPacket{
+ Handle: "somehandle",
+ Offset: offset,
+ Len: length,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 31,
+ 5,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0,
+ 0xFE, 0xDC, 0xBA, 0x98,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = ReadPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Handle != handle {
+ t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle)
+ }
+
+ if p.Offset != offset {
+ t.Errorf("UnmarshalPacketBody(): Offset was %x, but expected %x", p.Offset, offset)
+ }
+
+ if p.Len != length {
+ t.Errorf("UnmarshalPacketBody(): Len was %x, but expected %x", p.Len, length)
+ }
+}
+
+var _ Packet = &WritePacket{}
+
+func TestWritePacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ offset = 0x123456789ABCDEF0
+ )
+
+ var payload = []byte(`foobar`)
+
+ p := &WritePacket{
+ Handle: "somehandle",
+ Offset: offset,
+ Data: payload,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 37,
+ 6,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0,
+ 0x00, 0x00, 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = WritePacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Handle != handle {
+ t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle)
+ }
+
+ if p.Offset != offset {
+ t.Errorf("UnmarshalPacketBody(): Offset was %x, but expected %x", p.Offset, offset)
+ }
+
+ if !bytes.Equal(p.Data, payload) {
+ t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", p.Data, payload)
+ }
+}
+
+var _ Packet = &FStatPacket{}
+
+func TestFStatPacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ )
+
+ p := &FStatPacket{
+ Handle: "somehandle",
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 19,
+ 8,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = FStatPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Handle != handle {
+ t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle)
+ }
+}
+
+var _ Packet = &FSetstatPacket{}
+
+func TestFSetstatPacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ perms = 0x87654321
+ )
+
+ p := &FSetstatPacket{
+ Handle: "somehandle",
+ Attrs: Attributes{
+ Flags: AttrPermissions,
+ Permissions: perms,
+ },
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 27,
+ 10,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ 0x00, 0x00, 0x00, 0x04,
+ 0x87, 0x65, 0x43, 0x21,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = FSetstatPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Handle != handle {
+ t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle)
+ }
+}
+
+var _ Packet = &ReadDirPacket{}
+
+func TestReadDirPacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ )
+
+ p := &ReadDirPacket{
+ Handle: "somehandle",
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 19,
+ 12,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = ReadDirPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Handle != handle {
+ t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/init_packets.go b/internal/encoding/ssh/filexfer/init_packets.go
new file mode 100644
index 0000000..b0bc6f5
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/init_packets.go
@@ -0,0 +1,99 @@
+package filexfer
+
+// InitPacket defines the SSH_FXP_INIT packet.
+type InitPacket struct {
+ Version uint32
+ Extensions []*ExtensionPair
+}
+
+// MarshalBinary returns p as the binary encoding of p.
+func (p *InitPacket) MarshalBinary() ([]byte, error) {
+ size := 1 + 4 // byte(type) + uint32(version)
+
+ for _, ext := range p.Extensions {
+ size += ext.Len()
+ }
+
+ b := NewBuffer(make([]byte, 4, 4+size))
+ b.AppendUint8(uint8(PacketTypeInit))
+ b.AppendUint32(p.Version)
+
+ for _, ext := range p.Extensions {
+ ext.MarshalInto(b)
+ }
+
+ b.PutLength(size)
+
+ return b.Bytes(), nil
+}
+
+// UnmarshalBinary unmarshals a full raw packet out of the given data.
+// It is assumed that the uint32(length) has already been consumed to receive the data.
+// It is also assumed that the uint8(type) has already been consumed to which packet to unmarshal into.
+func (p *InitPacket) UnmarshalBinary(data []byte) (err error) {
+ buf := NewBuffer(data)
+
+ if p.Version, err = buf.ConsumeUint32(); err != nil {
+ return err
+ }
+
+ for buf.Len() > 0 {
+ var ext ExtensionPair
+ if err := ext.UnmarshalFrom(buf); err != nil {
+ return err
+ }
+
+ p.Extensions = append(p.Extensions, &ext)
+ }
+
+ return nil
+}
+
+// VersionPacket defines the SSH_FXP_VERSION packet.
+type VersionPacket struct {
+ Version uint32
+ Extensions []*ExtensionPair
+}
+
+// MarshalBinary returns p as the binary encoding of p.
+func (p *VersionPacket) MarshalBinary() ([]byte, error) {
+ size := 1 + 4 // byte(type) + uint32(version)
+
+ for _, ext := range p.Extensions {
+ size += ext.Len()
+ }
+
+ b := NewBuffer(make([]byte, 4, 4+size))
+ b.AppendUint8(uint8(PacketTypeVersion))
+ b.AppendUint32(p.Version)
+
+ for _, ext := range p.Extensions {
+ ext.MarshalInto(b)
+ }
+
+ b.PutLength(size)
+
+ return b.Bytes(), nil
+}
+
+// UnmarshalBinary unmarshals a full raw packet out of the given data.
+// It is assumed that the uint32(length) has already been consumed to receive the data.
+// It is also assumed that the uint8(type) has already been consumed to which packet to unmarshal into.
+func (p *VersionPacket) UnmarshalBinary(data []byte) (err error) {
+ buf := NewBuffer(data)
+
+ if p.Version, err = buf.ConsumeUint32(); err != nil {
+ return err
+ }
+
+ for buf.Len() > 0 {
+ var ext ExtensionPair
+ if err := ext.UnmarshalFrom(buf); err != nil {
+ return err
+ }
+
+ p.Extensions = append(p.Extensions, &ext)
+ }
+
+ return nil
+}
diff --git a/internal/encoding/ssh/filexfer/init_packets_test.go b/internal/encoding/ssh/filexfer/init_packets_test.go
new file mode 100644
index 0000000..e7605f9
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/init_packets_test.go
@@ -0,0 +1,114 @@
+package filexfer
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestInitPacket(t *testing.T) {
+ var version uint8 = 3
+
+ p := &InitPacket{
+ Version: uint32(version),
+ Extensions: []*ExtensionPair{
+ {
+ Name: "foo",
+ Data: "1",
+ },
+ },
+ }
+
+ buf, err := p.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 17,
+ 1,
+ 0x00, 0x00, 0x00, version,
+ 0x00, 0x00, 0x00, 3, 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 1, '1',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want)
+ }
+
+ *p = InitPacket{}
+
+ // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
+ if err := p.UnmarshalBinary(buf[5:]); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Version != uint32(version) {
+ t.Errorf("UnmarshalBinary(): Version was %d, but expected %d", p.Version, version)
+ }
+
+ if len(p.Extensions) != 1 {
+ t.Fatalf("UnmarshalBinary(): len(p.Extensions) was %d, but expected %d", len(p.Extensions), 1)
+ }
+
+ if got, want := p.Extensions[0].Name, "foo"; got != want {
+ t.Errorf("UnmarshalBinary(): p.Extensions[0].Name was %q, but expected %q", got, want)
+ }
+
+ if got, want := p.Extensions[0].Data, "1"; got != want {
+ t.Errorf("UnmarshalBinary(): p.Extensions[0].Data was %q, but expected %q", got, want)
+ }
+}
+
+func TestVersionPacket(t *testing.T) {
+ var version uint8 = 3
+
+ p := &VersionPacket{
+ Version: uint32(version),
+ Extensions: []*ExtensionPair{
+ {
+ Name: "foo",
+ Data: "1",
+ },
+ },
+ }
+
+ buf, err := p.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 17,
+ 2,
+ 0x00, 0x00, 0x00, version,
+ 0x00, 0x00, 0x00, 3, 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 1, '1',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want)
+ }
+
+ *p = VersionPacket{}
+
+ // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
+ if err := p.UnmarshalBinary(buf[5:]); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Version != uint32(version) {
+ t.Errorf("UnmarshalBinary(): Version was %d, but expected %d", p.Version, version)
+ }
+
+ if len(p.Extensions) != 1 {
+ t.Fatalf("UnmarshalBinary(): len(p.Extensions) was %d, but expected %d", len(p.Extensions), 1)
+ }
+
+ if got, want := p.Extensions[0].Name, "foo"; got != want {
+ t.Errorf("UnmarshalBinary(): p.Extensions[0].Name was %q, but expected %q", got, want)
+ }
+
+ if got, want := p.Extensions[0].Data, "1"; got != want {
+ t.Errorf("UnmarshalBinary(): p.Extensions[0].Data was %q, but expected %q", got, want)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/open_packets.go b/internal/encoding/ssh/filexfer/open_packets.go
new file mode 100644
index 0000000..1358711
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/open_packets.go
@@ -0,0 +1,89 @@
+package filexfer
+
+// SSH_FXF_* flags.
+const (
+ FlagRead = 1 << iota // SSH_FXF_READ
+ FlagWrite // SSH_FXF_WRITE
+ FlagAppend // SSH_FXF_APPEND
+ FlagCreate // SSH_FXF_CREAT
+ FlagTruncate // SSH_FXF_TRUNC
+ FlagExclusive // SSH_FXF_EXCL
+)
+
+// OpenPacket defines the SSH_FXP_OPEN packet.
+type OpenPacket struct {
+ Filename string
+ PFlags uint32
+ Attrs Attributes
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *OpenPacket) Type() PacketType {
+ return PacketTypeOpen
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *OpenPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ // string(filename) + uint32(pflags) + ATTRS(attrs)
+ size := 4 + len(p.Filename) + 4 + p.Attrs.Len()
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeOpen, reqid)
+ buf.AppendString(p.Filename)
+ buf.AppendUint32(p.PFlags)
+
+ p.Attrs.MarshalInto(buf)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *OpenPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Filename, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if p.PFlags, err = buf.ConsumeUint32(); err != nil {
+ return err
+ }
+
+ return p.Attrs.UnmarshalFrom(buf)
+}
+
+// OpenDirPacket defines the SSH_FXP_OPENDIR packet.
+type OpenDirPacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *OpenDirPacket) Type() PacketType {
+ return PacketTypeOpenDir
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *OpenDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) // string(path)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeOpenDir, reqid)
+ buf.AppendString(p.Path)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *OpenDirPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/encoding/ssh/filexfer/open_packets_test.go b/internal/encoding/ssh/filexfer/open_packets_test.go
new file mode 100644
index 0000000..560c8b4
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/open_packets_test.go
@@ -0,0 +1,107 @@
+package filexfer
+
+import (
+ "bytes"
+ "testing"
+)
+
+var _ Packet = &OpenPacket{}
+
+func TestOpenPacket(t *testing.T) {
+ const (
+ id = 42
+ filename = "/foo"
+ perms = 0x87654321
+ )
+
+ p := &OpenPacket{
+ Filename: "/foo",
+ PFlags: FlagRead,
+ Attrs: Attributes{
+ Flags: AttrPermissions,
+ Permissions: perms,
+ },
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 25,
+ 3,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 1,
+ 0x00, 0x00, 0x00, 0x04,
+ 0x87, 0x65, 0x43, 0x21,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = OpenPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Filename != filename {
+ t.Errorf("UnmarshalPacketBody(): Filename was %q, but expected %q", p.Filename, filename)
+ }
+
+ if p.PFlags != FlagRead {
+ t.Errorf("UnmarshalPacketBody(): PFlags was %#x, but expected %#x", p.PFlags, FlagRead)
+ }
+
+ if p.Attrs.Flags != AttrPermissions {
+ t.Errorf("UnmarshalPacketBody(): Attrs.Flags was %#x, but expected %#x", p.Attrs.Flags, AttrPermissions)
+ }
+
+ if p.Attrs.Permissions != perms {
+ t.Errorf("UnmarshalPacketBody(): Attrs.Permissions was %#v, but expected %#v", p.Attrs.Permissions, perms)
+ }
+}
+
+var _ Packet = &OpenDirPacket{}
+
+func TestOpenDirPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ p := &OpenDirPacket{
+ Path: path,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 13,
+ 11,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = OpenDirPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/fsync.go b/internal/encoding/ssh/filexfer/openssh/fsync.go
new file mode 100644
index 0000000..7ecfb0c
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/fsync.go
@@ -0,0 +1,73 @@
+package openssh
+
+import (
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+const extensionFSync = "fsync@openssh.com"
+
+// RegisterExtensionFSync registers the "fsync@openssh.com" extended packet with the encoding/ssh/filexfer package.
+func RegisterExtensionFSync() {
+ sshfx.RegisterExtendedPacketType(extensionFSync, func() sshfx.ExtendedData {
+ return new(FSyncExtendedPacket)
+ })
+}
+
+// ExtensionFSync returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
+func ExtensionFSync() *sshfx.ExtensionPair {
+ return &sshfx.ExtensionPair{
+ Name: extensionFSync,
+ Data: "1",
+ }
+}
+
+// FSyncExtendedPacket defines the fsync@openssh.com extend packet.
+type FSyncExtendedPacket struct {
+ Handle string
+}
+
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *FSyncExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
+// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
+func (ep *FSyncExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ p := &sshfx.ExtendedPacket{
+ ExtendedRequest: extensionFSync,
+
+ Data: ep,
+ }
+ return p.MarshalPacket(reqid, b)
+}
+
+// MarshalInto encodes ep into the binary encoding of the fsync@openssh.com extended packet-specific data.
+func (ep *FSyncExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
+ buf.AppendString(ep.Handle)
+}
+
+// MarshalBinary encodes ep into the binary encoding of the fsync@openssh.com extended packet-specific data.
+//
+// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
+func (ep *FSyncExtendedPacket) MarshalBinary() ([]byte, error) {
+ // string(handle)
+ size := 4 + len(ep.Handle)
+
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
+ ep.MarshalInto(buf)
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom decodes the fsync@openssh.com extended packet-specific data from buf.
+func (ep *FSyncExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
+ if ep.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the fsync@openssh.com extended packet-specific data into ep.
+func (ep *FSyncExtendedPacket) UnmarshalBinary(data []byte) (err error) {
+ return ep.UnmarshalFrom(sshfx.NewBuffer(data))
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/fsync_test.go b/internal/encoding/ssh/filexfer/openssh/fsync_test.go
new file mode 100644
index 0000000..f9e878f
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/fsync_test.go
@@ -0,0 +1,62 @@
+package openssh
+
+import (
+ "bytes"
+ "testing"
+
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+var _ sshfx.PacketMarshaller = &FSyncExtendedPacket{}
+
+func init() {
+ RegisterExtensionFSync()
+}
+
+func TestFSyncExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ )
+
+ ep := &FSyncExtendedPacket{
+ Handle: handle,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 40,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 17, 'f', 's', 'y', 'n', 'c', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionFSync {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionFSync)
+ }
+
+ ep, ok := p.Data.(*FSyncExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *FSyncExtendedPacket", p.Data)
+ }
+
+ if ep.Handle != handle {
+ t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", ep.Handle, handle)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/hardlink.go b/internal/encoding/ssh/filexfer/openssh/hardlink.go
new file mode 100644
index 0000000..17c3499
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/hardlink.go
@@ -0,0 +1,79 @@
+package openssh
+
+import (
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+const extensionHardlink = "hardlink@openssh.com"
+
+// RegisterExtensionHardlink registers the "hardlink@openssh.com" extended packet with the encoding/ssh/filexfer package.
+func RegisterExtensionHardlink() {
+ sshfx.RegisterExtendedPacketType(extensionHardlink, func() sshfx.ExtendedData {
+ return new(HardlinkExtendedPacket)
+ })
+}
+
+// ExtensionHardlink returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
+func ExtensionHardlink() *sshfx.ExtensionPair {
+ return &sshfx.ExtensionPair{
+ Name: extensionHardlink,
+ Data: "1",
+ }
+}
+
+// HardlinkExtendedPacket defines the hardlink@openssh.com extend packet.
+type HardlinkExtendedPacket struct {
+ OldPath string
+ NewPath string
+}
+
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *HardlinkExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
+// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
+func (ep *HardlinkExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ p := &sshfx.ExtendedPacket{
+ ExtendedRequest: extensionHardlink,
+
+ Data: ep,
+ }
+ return p.MarshalPacket(reqid, b)
+}
+
+// MarshalInto encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data.
+func (ep *HardlinkExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
+ buf.AppendString(ep.OldPath)
+ buf.AppendString(ep.NewPath)
+}
+
+// MarshalBinary encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data.
+//
+// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
+func (ep *HardlinkExtendedPacket) MarshalBinary() ([]byte, error) {
+ // string(oldpath) + string(newpath)
+ size := 4 + len(ep.OldPath) + 4 + len(ep.NewPath)
+
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
+ ep.MarshalInto(buf)
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom decodes the hardlink@openssh.com extended packet-specific data from buf.
+func (ep *HardlinkExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
+ if ep.OldPath, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if ep.NewPath, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the hardlink@openssh.com extended packet-specific data into ep.
+func (ep *HardlinkExtendedPacket) UnmarshalBinary(data []byte) (err error) {
+ return ep.UnmarshalFrom(sshfx.NewBuffer(data))
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/hardlink_test.go b/internal/encoding/ssh/filexfer/openssh/hardlink_test.go
new file mode 100644
index 0000000..5d3be06
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/hardlink_test.go
@@ -0,0 +1,69 @@
+package openssh
+
+import (
+ "bytes"
+ "testing"
+
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+var _ sshfx.PacketMarshaller = &HardlinkExtendedPacket{}
+
+func init() {
+ RegisterExtensionHardlink()
+}
+
+func TestHardlinkExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ oldpath = "/foo"
+ newpath = "/bar"
+ )
+
+ ep := &HardlinkExtendedPacket{
+ OldPath: oldpath,
+ NewPath: newpath,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 45,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 20, 'h', 'a', 'r', 'd', 'l', 'i', 'n', 'k', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionHardlink {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionHardlink)
+ }
+
+ ep, ok := p.Data.(*HardlinkExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *HardlinkExtendedPacket", p.Data)
+ }
+
+ if ep.OldPath != oldpath {
+ t.Errorf("UnmarshalPacketBody(): OldPath was %q, but expected %q", ep.OldPath, oldpath)
+ }
+
+ if ep.NewPath != newpath {
+ t.Errorf("UnmarshalPacketBody(): NewPath was %q, but expected %q", ep.NewPath, newpath)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/openssh.go b/internal/encoding/ssh/filexfer/openssh/openssh.go
new file mode 100644
index 0000000..f93ff17
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/openssh.go
@@ -0,0 +1,2 @@
+// Package openssh implements the openssh secsh-filexfer extensions as described in https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
+package openssh
diff --git a/internal/encoding/ssh/filexfer/openssh/posix-rename.go b/internal/encoding/ssh/filexfer/openssh/posix-rename.go
new file mode 100644
index 0000000..a3d3de5
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/posix-rename.go
@@ -0,0 +1,79 @@
+package openssh
+
+import (
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+const extensionPosixRename = "posix-rename@openssh.com"
+
+// RegisterExtensionPosixRename registers the "posix-rename@openssh.com" extended packet with the encoding/ssh/filexfer package.
+func RegisterExtensionPosixRename() {
+ sshfx.RegisterExtendedPacketType(extensionPosixRename, func() sshfx.ExtendedData {
+ return new(PosixRenameExtendedPacket)
+ })
+}
+
+// ExtensionPosixRename returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
+func ExtensionPosixRename() *sshfx.ExtensionPair {
+ return &sshfx.ExtensionPair{
+ Name: extensionPosixRename,
+ Data: "1",
+ }
+}
+
+// PosixRenameExtendedPacket defines the posix-rename@openssh.com extend packet.
+type PosixRenameExtendedPacket struct {
+ OldPath string
+ NewPath string
+}
+
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *PosixRenameExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
+// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
+func (ep *PosixRenameExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ p := &sshfx.ExtendedPacket{
+ ExtendedRequest: extensionPosixRename,
+
+ Data: ep,
+ }
+ return p.MarshalPacket(reqid, b)
+}
+
+// MarshalInto encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data.
+func (ep *PosixRenameExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
+ buf.AppendString(ep.OldPath)
+ buf.AppendString(ep.NewPath)
+}
+
+// MarshalBinary encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data.
+//
+// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
+func (ep *PosixRenameExtendedPacket) MarshalBinary() ([]byte, error) {
+ // string(oldpath) + string(newpath)
+ size := 4 + len(ep.OldPath) + 4 + len(ep.NewPath)
+
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
+ ep.MarshalInto(buf)
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom decodes the hardlink@openssh.com extended packet-specific data from buf.
+func (ep *PosixRenameExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
+ if ep.OldPath, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if ep.NewPath, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the hardlink@openssh.com extended packet-specific data into ep.
+func (ep *PosixRenameExtendedPacket) UnmarshalBinary(data []byte) (err error) {
+ return ep.UnmarshalFrom(sshfx.NewBuffer(data))
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go b/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go
new file mode 100644
index 0000000..6bdb10d
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go
@@ -0,0 +1,69 @@
+package openssh
+
+import (
+ "bytes"
+ "testing"
+
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+var _ sshfx.PacketMarshaller = &PosixRenameExtendedPacket{}
+
+func init() {
+ RegisterExtensionPosixRename()
+}
+
+func TestPosixRenameExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ oldpath = "/foo"
+ newpath = "/bar"
+ )
+
+ ep := &PosixRenameExtendedPacket{
+ OldPath: oldpath,
+ NewPath: newpath,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 49,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 24, 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionPosixRename {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionPosixRename)
+ }
+
+ ep, ok := p.Data.(*PosixRenameExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *PosixRenameExtendedPacket", p.Data)
+ }
+
+ if ep.OldPath != oldpath {
+ t.Errorf("UnmarshalPacketBody(): OldPath was %q, but expected %q", ep.OldPath, oldpath)
+ }
+
+ if ep.NewPath != newpath {
+ t.Errorf("UnmarshalPacketBody(): NewPath was %q, but expected %q", ep.NewPath, newpath)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/statvfs.go b/internal/encoding/ssh/filexfer/openssh/statvfs.go
new file mode 100644
index 0000000..3e9015f
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/statvfs.go
@@ -0,0 +1,256 @@
+package openssh
+
+import (
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+const extensionStatVFS = "statvfs@openssh.com"
+
+// RegisterExtensionStatVFS registers the "statvfs@openssh.com" extended packet with the encoding/ssh/filexfer package.
+func RegisterExtensionStatVFS() {
+ sshfx.RegisterExtendedPacketType(extensionStatVFS, func() sshfx.ExtendedData {
+ return new(StatVFSExtendedPacket)
+ })
+}
+
+// ExtensionStatVFS returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
+func ExtensionStatVFS() *sshfx.ExtensionPair {
+ return &sshfx.ExtensionPair{
+ Name: extensionStatVFS,
+ Data: "2",
+ }
+}
+
+// StatVFSExtendedPacket defines the statvfs@openssh.com extend packet.
+type StatVFSExtendedPacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *StatVFSExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
+// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
+func (ep *StatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ p := &sshfx.ExtendedPacket{
+ ExtendedRequest: extensionStatVFS,
+
+ Data: ep,
+ }
+ return p.MarshalPacket(reqid, b)
+}
+
+// MarshalInto encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data.
+func (ep *StatVFSExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
+ buf.AppendString(ep.Path)
+}
+
+// MarshalBinary encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data.
+//
+// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
+func (ep *StatVFSExtendedPacket) MarshalBinary() ([]byte, error) {
+ size := 4 + len(ep.Path) // string(path)
+
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
+
+ ep.MarshalInto(buf)
+
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom decodes the statvfs@openssh.com extended packet-specific data into ep.
+func (ep *StatVFSExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
+ if ep.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the statvfs@openssh.com extended packet-specific data into ep.
+func (ep *StatVFSExtendedPacket) UnmarshalBinary(data []byte) (err error) {
+ return ep.UnmarshalFrom(sshfx.NewBuffer(data))
+}
+
+const extensionFStatVFS = "fstatvfs@openssh.com"
+
+// RegisterExtensionFStatVFS registers the "fstatvfs@openssh.com" extended packet with the encoding/ssh/filexfer package.
+func RegisterExtensionFStatVFS() {
+ sshfx.RegisterExtendedPacketType(extensionFStatVFS, func() sshfx.ExtendedData {
+ return new(FStatVFSExtendedPacket)
+ })
+}
+
+// ExtensionFStatVFS returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
+func ExtensionFStatVFS() *sshfx.ExtensionPair {
+ return &sshfx.ExtensionPair{
+ Name: extensionFStatVFS,
+ Data: "2",
+ }
+}
+
+// FStatVFSExtendedPacket defines the fstatvfs@openssh.com extend packet.
+type FStatVFSExtendedPacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *FStatVFSExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
+// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
+func (ep *FStatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ p := &sshfx.ExtendedPacket{
+ ExtendedRequest: extensionFStatVFS,
+
+ Data: ep,
+ }
+ return p.MarshalPacket(reqid, b)
+}
+
+// MarshalInto encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data.
+func (ep *FStatVFSExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
+ buf.AppendString(ep.Path)
+}
+
+// MarshalBinary encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data.
+//
+// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
+func (ep *FStatVFSExtendedPacket) MarshalBinary() ([]byte, error) {
+ size := 4 + len(ep.Path) // string(path)
+
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
+
+ ep.MarshalInto(buf)
+
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom decodes the statvfs@openssh.com extended packet-specific data into ep.
+func (ep *FStatVFSExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
+ if ep.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the statvfs@openssh.com extended packet-specific data into ep.
+func (ep *FStatVFSExtendedPacket) UnmarshalBinary(data []byte) (err error) {
+ return ep.UnmarshalFrom(sshfx.NewBuffer(data))
+}
+
+// The values for the MountFlags field.
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
+const (
+ MountFlagsReadOnly = 0x1 // SSH_FXE_STATVFS_ST_RDONLY
+ MountFlagsNoSUID = 0x2 // SSH_FXE_STATVFS_ST_NOSUID
+)
+
+// StatVFSExtendedReplyPacket defines the extended reply packet for statvfs@openssh.com and fstatvfs@openssh.com requests.
+type StatVFSExtendedReplyPacket struct {
+ BlockSize uint64 /* f_bsize: file system block size */
+ FragmentSize uint64 /* f_frsize: fundamental fs block size / fagment size */
+ Blocks uint64 /* f_blocks: number of blocks (unit f_frsize) */
+ BlocksFree uint64 /* f_bfree: free blocks in filesystem */
+ BlocksAvail uint64 /* f_bavail: free blocks for non-root */
+ Files uint64 /* f_files: total file inodes */
+ FilesFree uint64 /* f_ffree: free file inodes */
+ FilesAvail uint64 /* f_favail: free file inodes for to non-root */
+ FilesystemID uint64 /* f_fsid: file system id */
+ MountFlags uint64 /* f_flag: bit mask of mount flag values */
+ MaxNameLength uint64 /* f_namemax: maximum filename length */
+}
+
+// Type returns the SSH_FXP_EXTENDED_REPLY packet type.
+func (ep *StatVFSExtendedReplyPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtendedReply
+}
+
+// MarshalPacket returns ep as a two-part binary encoding of the full extended reply packet.
+func (ep *StatVFSExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ p := &sshfx.ExtendedReplyPacket{
+ Data: ep,
+ }
+ return p.MarshalPacket(reqid, b)
+}
+
+// UnmarshalPacketBody returns ep as a two-part binary encoding of the full extended reply packet.
+func (ep *StatVFSExtendedReplyPacket) UnmarshalPacketBody(buf *sshfx.Buffer) (err error) {
+ p := &sshfx.ExtendedReplyPacket{
+ Data: ep,
+ }
+ return p.UnmarshalPacketBody(buf)
+}
+
+// MarshalInto encodes ep into the binary encoding of the (f)statvfs@openssh.com extended reply packet-specific data.
+func (ep *StatVFSExtendedReplyPacket) MarshalInto(buf *sshfx.Buffer) {
+ buf.AppendUint64(ep.BlockSize)
+ buf.AppendUint64(ep.FragmentSize)
+ buf.AppendUint64(ep.Blocks)
+ buf.AppendUint64(ep.BlocksFree)
+ buf.AppendUint64(ep.BlocksAvail)
+ buf.AppendUint64(ep.Files)
+ buf.AppendUint64(ep.FilesFree)
+ buf.AppendUint64(ep.FilesAvail)
+ buf.AppendUint64(ep.FilesystemID)
+ buf.AppendUint64(ep.MountFlags)
+ buf.AppendUint64(ep.MaxNameLength)
+}
+
+// MarshalBinary encodes ep into the binary encoding of the (f)statvfs@openssh.com extended reply packet-specific data.
+//
+// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended reply packet.
+func (ep *StatVFSExtendedReplyPacket) MarshalBinary() ([]byte, error) {
+ size := 11 * 8 // 11 × uint64(various)
+
+ b := sshfx.NewBuffer(make([]byte, 0, size))
+ ep.MarshalInto(b)
+ return b.Bytes(), nil
+}
+
+// UnmarshalFrom decodes the fstatvfs@openssh.com extended reply packet-specific data into ep.
+func (ep *StatVFSExtendedReplyPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
+ if ep.BlockSize, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.FragmentSize, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.Blocks, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.BlocksFree, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.BlocksAvail, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.Files, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.FilesFree, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.FilesAvail, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.FilesystemID, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.MountFlags, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+ if ep.MaxNameLength, err = buf.ConsumeUint64(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the fstatvfs@openssh.com extended reply packet-specific data into ep.
+func (ep *StatVFSExtendedReplyPacket) UnmarshalBinary(data []byte) (err error) {
+ return ep.UnmarshalFrom(sshfx.NewBuffer(data))
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/statvfs_test.go b/internal/encoding/ssh/filexfer/openssh/statvfs_test.go
new file mode 100644
index 0000000..014aa63
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/statvfs_test.go
@@ -0,0 +1,239 @@
+package openssh
+
+import (
+ "bytes"
+ "testing"
+
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+var _ sshfx.PacketMarshaller = &StatVFSExtendedPacket{}
+
+func init() {
+ RegisterExtensionStatVFS()
+}
+
+func TestStatVFSExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ ep := &StatVFSExtendedPacket{
+ Path: path,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 36,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 19, 's', 't', 'a', 't', 'v', 'f', 's', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionStatVFS {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionStatVFS)
+ }
+
+ ep, ok := p.Data.(*StatVFSExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *StatVFSExtendedPacket", p.Data)
+ }
+
+ if ep.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", ep.Path, path)
+ }
+}
+
+var _ sshfx.PacketMarshaller = &FStatVFSExtendedPacket{}
+
+func init() {
+ RegisterExtensionFStatVFS()
+}
+
+func TestFStatVFSExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ ep := &FStatVFSExtendedPacket{
+ Path: path,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 37,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 20, 'f', 's', 't', 'a', 't', 'v', 'f', 's', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionFStatVFS {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionFStatVFS)
+ }
+
+ ep, ok := p.Data.(*FStatVFSExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *FStatVFSExtendedPacket", p.Data)
+ }
+
+ if ep.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", ep.Path, path)
+ }
+}
+
+var _ sshfx.Packet = &StatVFSExtendedReplyPacket{}
+
+func TestStatVFSExtendedReplyPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ const (
+ BlockSize = uint64(iota + 13)
+ FragmentSize
+ Blocks
+ BlocksFree
+ BlocksAvail
+ Files
+ FilesFree
+ FilesAvail
+ FilesystemID
+ MountFlags
+ MaxNameLength
+ )
+
+ ep := &StatVFSExtendedReplyPacket{
+ BlockSize: BlockSize,
+ FragmentSize: FragmentSize,
+ Blocks: Blocks,
+ BlocksFree: BlocksFree,
+ BlocksAvail: BlocksAvail,
+ Files: Files,
+ FilesFree: FilesFree,
+ FilesAvail: FilesAvail,
+ FilesystemID: FilesystemID,
+ MountFlags: MountFlags,
+ MaxNameLength: MaxNameLength,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 93,
+ 201,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 13,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 14,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 15,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 16,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 17,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 18,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 19,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 21,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 22,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 23,
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ *ep = StatVFSExtendedReplyPacket{}
+
+ p := sshfx.ExtendedReplyPacket{
+ Data: ep,
+ }
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ ep, ok := p.Data.(*StatVFSExtendedReplyPacket)
+ if !ok {
+ t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *StatVFSExtendedReplyPacket", p.Data)
+ }
+
+ if ep.BlockSize != BlockSize {
+ t.Errorf("UnmarshalPacketBody(): BlockSize was %d, but expected %d", ep.BlockSize, BlockSize)
+ }
+
+ if ep.FragmentSize != FragmentSize {
+ t.Errorf("UnmarshalPacketBody(): FragmentSize was %d, but expected %d", ep.FragmentSize, FragmentSize)
+ }
+
+ if ep.Blocks != Blocks {
+ t.Errorf("UnmarshalPacketBody(): Blocks was %d, but expected %d", ep.Blocks, Blocks)
+ }
+
+ if ep.BlocksFree != BlocksFree {
+ t.Errorf("UnmarshalPacketBody(): BlocksFree was %d, but expected %d", ep.BlocksFree, BlocksFree)
+ }
+
+ if ep.BlocksAvail != BlocksAvail {
+ t.Errorf("UnmarshalPacketBody(): BlocksAvail was %d, but expected %d", ep.BlocksAvail, BlocksAvail)
+ }
+
+ if ep.Files != Files {
+ t.Errorf("UnmarshalPacketBody(): Files was %d, but expected %d", ep.Files, Files)
+ }
+
+ if ep.FilesFree != FilesFree {
+ t.Errorf("UnmarshalPacketBody(): FilesFree was %d, but expected %d", ep.FilesFree, FilesFree)
+ }
+
+ if ep.FilesAvail != FilesAvail {
+ t.Errorf("UnmarshalPacketBody(): FilesAvail was %d, but expected %d", ep.FilesAvail, FilesAvail)
+ }
+
+ if ep.FilesystemID != FilesystemID {
+ t.Errorf("UnmarshalPacketBody(): FilesystemID was %d, but expected %d", ep.FilesystemID, FilesystemID)
+ }
+
+ if ep.MountFlags != MountFlags {
+ t.Errorf("UnmarshalPacketBody(): MountFlags was %d, but expected %d", ep.MountFlags, MountFlags)
+ }
+
+ if ep.MaxNameLength != MaxNameLength {
+ t.Errorf("UnmarshalPacketBody(): MaxNameLength was %d, but expected %d", ep.MaxNameLength, MaxNameLength)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/packets.go b/internal/encoding/ssh/filexfer/packets.go
new file mode 100644
index 0000000..3f24e9c
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/packets.go
@@ -0,0 +1,323 @@
+package filexfer
+
+import (
+ "errors"
+ "fmt"
+ "io"
+)
+
+// smallBufferSize is an initial allocation minimal capacity.
+const smallBufferSize = 64
+
+func newPacketFromType(typ PacketType) (Packet, error) {
+ switch typ {
+ case PacketTypeOpen:
+ return new(OpenPacket), nil
+ case PacketTypeClose:
+ return new(ClosePacket), nil
+ case PacketTypeRead:
+ return new(ReadPacket), nil
+ case PacketTypeWrite:
+ return new(WritePacket), nil
+ case PacketTypeLStat:
+ return new(LStatPacket), nil
+ case PacketTypeFStat:
+ return new(FStatPacket), nil
+ case PacketTypeSetstat:
+ return new(SetstatPacket), nil
+ case PacketTypeFSetstat:
+ return new(FSetstatPacket), nil
+ case PacketTypeOpenDir:
+ return new(OpenDirPacket), nil
+ case PacketTypeReadDir:
+ return new(ReadDirPacket), nil
+ case PacketTypeRemove:
+ return new(RemovePacket), nil
+ case PacketTypeMkdir:
+ return new(MkdirPacket), nil
+ case PacketTypeRmdir:
+ return new(RmdirPacket), nil
+ case PacketTypeRealPath:
+ return new(RealPathPacket), nil
+ case PacketTypeStat:
+ return new(StatPacket), nil
+ case PacketTypeRename:
+ return new(RenamePacket), nil
+ case PacketTypeReadLink:
+ return new(ReadLinkPacket), nil
+ case PacketTypeSymlink:
+ return new(SymlinkPacket), nil
+ case PacketTypeExtended:
+ return new(ExtendedPacket), nil
+ default:
+ return nil, fmt.Errorf("unexpected request packet type: %v", typ)
+ }
+}
+
+// RawPacket implements the general packet format from draft-ietf-secsh-filexfer-02
+//
+// RawPacket is intended for use in clients receiving responses,
+// where a response will be expected to be of a limited number of types,
+// and unmarshaling unknown/unexpected response packets is unnecessary.
+//
+// For servers expecting to receive arbitrary request packet types,
+// use RequestPacket.
+//
+// Defined in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-3
+type RawPacket struct {
+ PacketType PacketType
+ RequestID uint32
+
+ Data Buffer
+}
+
+// Type returns the Type field defining the SSH_FXP_xy type for this packet.
+func (p *RawPacket) Type() PacketType {
+ return p.PacketType
+}
+
+// Reset clears the pointers and reference-semantic variables of RawPacket,
+// releasing underlying resources, and making them and the RawPacket suitable to be reused,
+// so long as no other references have been kept.
+func (p *RawPacket) Reset() {
+ p.Data = Buffer{}
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+//
+// The internal p.RequestID is overridden by the reqid argument.
+func (p *RawPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ buf = NewMarshalBuffer(0)
+ }
+
+ buf.StartPacket(p.PacketType, reqid)
+
+ return buf.Packet(p.Data.Bytes())
+}
+
+// MarshalBinary returns p as the binary encoding of p.
+//
+// This is a convenience implementation primarily intended for tests,
+// because it is inefficient with allocations.
+func (p *RawPacket) MarshalBinary() ([]byte, error) {
+ return ComposePacket(p.MarshalPacket(p.RequestID, nil))
+}
+
+// UnmarshalFrom decodes a RawPacket from the given Buffer into p.
+//
+// The Data field will alias the passed in Buffer,
+// so the buffer passed in should not be reused before RawPacket.Reset().
+func (p *RawPacket) UnmarshalFrom(buf *Buffer) error {
+ typ, err := buf.ConsumeUint8()
+ if err != nil {
+ return err
+ }
+
+ p.PacketType = PacketType(typ)
+
+ if p.RequestID, err = buf.ConsumeUint32(); err != nil {
+ return err
+ }
+
+ p.Data = *buf
+ return nil
+}
+
+// UnmarshalBinary decodes a full raw packet out of the given data.
+// It is assumed that the uint32(length) has already been consumed to receive the data.
+//
+// This is a convenience implementation primarily intended for tests,
+// because this must clone the given data byte slice,
+// as Data is not allowed to alias any part of the data byte slice.
+func (p *RawPacket) UnmarshalBinary(data []byte) error {
+ clone := make([]byte, len(data))
+ n := copy(clone, data)
+ return p.UnmarshalFrom(NewBuffer(clone[:n]))
+}
+
+// readPacket reads a uint32 length-prefixed binary data packet from r.
+// using the given byte slice as a backing array.
+//
+// If the packet length read from r is bigger than maxPacketLength,
+// or greater than math.MaxInt32 on a 32-bit implementation,
+// then a `ErrLongPacket` error will be returned.
+//
+// If the given byte slice is insufficient to hold the packet,
+// then it will be extended to fill the packet size.
+func readPacket(r io.Reader, b []byte, maxPacketLength uint32) ([]byte, error) {
+ if cap(b) < 4 {
+ // We will need allocate our own buffer just for reading the packet length.
+
+ // However, we don’t really want to allocate an extremely narrow buffer (4-bytes),
+ // and cause unnecessary allocation churn from both length reads and small packet reads,
+ // so we use smallBufferSize from the bytes package as a reasonable guess.
+
+ // But if callers really do want to force narrow throw-away allocation of every packet body,
+ // they can do so with a buffer of capacity 4.
+ b = make([]byte, smallBufferSize)
+ }
+
+ if _, err := io.ReadFull(r, b[:4]); err != nil {
+ return nil, err
+ }
+
+ length := unmarshalUint32(b)
+ if int(length) < 5 {
+ // Must have at least uint8(type) and uint32(request-id)
+
+ if int(length) < 0 {
+ // Only possible when strconv.IntSize == 32,
+ // the packet length is longer than math.MaxInt32,
+ // and thus longer than any possible slice.
+ return nil, ErrLongPacket
+ }
+
+ return nil, ErrShortPacket
+ }
+ if length > maxPacketLength {
+ return nil, ErrLongPacket
+ }
+
+ if int(length) > cap(b) {
+ // We know int(length) must be positive, because of tests above.
+ b = make([]byte, length)
+ }
+
+ n, err := io.ReadFull(r, b[:length])
+ return b[:n], err
+}
+
+// ReadFrom provides a simple functional packet reader,
+// using the given byte slice as a backing array.
+//
+// To protect against potential denial of service attacks,
+// if the read packet length is longer than maxPacketLength,
+// then no packet data will be read, and ErrLongPacket will be returned.
+// (On 32-bit int architectures, all packets >= 2^31 in length
+// will return ErrLongPacket regardless of maxPacketLength.)
+//
+// If the read packet length is longer than cap(b),
+// then a throw-away slice will allocated to meet the exact packet length.
+// This can be used to limit the length of reused buffers,
+// while still allowing reception of occasional large packets.
+//
+// The Data field may alias the passed in byte slice,
+// so the byte slice passed in should not be reused before RawPacket.Reset().
+func (p *RawPacket) ReadFrom(r io.Reader, b []byte, maxPacketLength uint32) error {
+ b, err := readPacket(r, b, maxPacketLength)
+ if err != nil {
+ return err
+ }
+
+ return p.UnmarshalFrom(NewBuffer(b))
+}
+
+// RequestPacket implements the general packet format from draft-ietf-secsh-filexfer-02
+// but also automatically decode/encodes valid request packets (2 < type < 100 || type == 200).
+//
+// RequestPacket is intended for use in servers receiving requests,
+// where any arbitrary request may be received, and so decoding them automatically
+// is useful.
+//
+// For clients expecting to receive specific response packet types,
+// where automatic unmarshaling of the packet body does not make sense,
+// use RawPacket.
+//
+// Defined in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-3
+type RequestPacket struct {
+ RequestID uint32
+
+ Request Packet
+}
+
+// Type returns the SSH_FXP_xy value associated with the underlying packet.
+func (p *RequestPacket) Type() PacketType {
+ return p.Request.Type()
+}
+
+// Reset clears the pointers and reference-semantic variables in RequestPacket,
+// releasing underlying resources, and making them and the RequestPacket suitable to be reused,
+// so long as no other references have been kept.
+func (p *RequestPacket) Reset() {
+ p.Request = nil
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+//
+// The internal p.RequestID is overridden by the reqid argument.
+func (p *RequestPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ if p.Request == nil {
+ return nil, nil, errors.New("empty request packet")
+ }
+
+ return p.Request.MarshalPacket(reqid, b)
+}
+
+// MarshalBinary returns p as the binary encoding of p.
+//
+// This is a convenience implementation primarily intended for tests,
+// because it is inefficient with allocations.
+func (p *RequestPacket) MarshalBinary() ([]byte, error) {
+ return ComposePacket(p.MarshalPacket(p.RequestID, nil))
+}
+
+// UnmarshalFrom decodes a RequestPacket from the given Buffer into p.
+//
+// The Request field may alias the passed in Buffer, (e.g. SSH_FXP_WRITE),
+// so the buffer passed in should not be reused before RequestPacket.Reset().
+func (p *RequestPacket) UnmarshalFrom(buf *Buffer) error {
+ typ, err := buf.ConsumeUint8()
+ if err != nil {
+ return err
+ }
+
+ p.Request, err = newPacketFromType(PacketType(typ))
+ if err != nil {
+ return err
+ }
+
+ if p.RequestID, err = buf.ConsumeUint32(); err != nil {
+ return err
+ }
+
+ return p.Request.UnmarshalPacketBody(buf)
+}
+
+// UnmarshalBinary decodes a full request packet out of the given data.
+// It is assumed that the uint32(length) has already been consumed to receive the data.
+//
+// This is a convenience implementation primarily intended for tests,
+// because this must clone the given data byte slice,
+// as Request is not allowed to alias any part of the data byte slice.
+func (p *RequestPacket) UnmarshalBinary(data []byte) error {
+ clone := make([]byte, len(data))
+ n := copy(clone, data)
+ return p.UnmarshalFrom(NewBuffer(clone[:n]))
+}
+
+// ReadFrom provides a simple functional packet reader,
+// using the given byte slice as a backing array.
+//
+// To protect against potential denial of service attacks,
+// if the read packet length is longer than maxPacketLength,
+// then no packet data will be read, and ErrLongPacket will be returned.
+// (On 32-bit int architectures, all packets >= 2^31 in length
+// will return ErrLongPacket regardless of maxPacketLength.)
+//
+// If the read packet length is longer than cap(b),
+// then a throw-away slice will allocated to meet the exact packet length.
+// This can be used to limit the length of reused buffers,
+// while still allowing reception of occasional large packets.
+//
+// The Request field may alias the passed in byte slice,
+// so the byte slice passed in should not be reused before RawPacket.Reset().
+func (p *RequestPacket) ReadFrom(r io.Reader, b []byte, maxPacketLength uint32) error {
+ b, err := readPacket(r, b, maxPacketLength)
+ if err != nil {
+ return err
+ }
+
+ return p.UnmarshalFrom(NewBuffer(b))
+}
diff --git a/internal/encoding/ssh/filexfer/packets_test.go b/internal/encoding/ssh/filexfer/packets_test.go
new file mode 100644
index 0000000..1600920
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/packets_test.go
@@ -0,0 +1,132 @@
+package filexfer
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestRawPacket(t *testing.T) {
+ const (
+ id = 42
+ errMsg = "eof"
+ langTag = "en"
+ )
+
+ p := &RawPacket{
+ PacketType: PacketTypeStatus,
+ RequestID: id,
+ Data: Buffer{
+ b: []byte{
+ 0x00, 0x00, 0x00, 0x01,
+ 0x00, 0x00, 0x00, 0x03, 'e', 'o', 'f',
+ 0x00, 0x00, 0x00, 0x02, 'e', 'n',
+ },
+ },
+ }
+
+ buf, err := p.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 22,
+ 101,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 0x01,
+ 0x00, 0x00, 0x00, 3, 'e', 'o', 'f',
+ 0x00, 0x00, 0x00, 2, 'e', 'n',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Errorf("RawPacket.MarshalBinary() = %X, but wanted %X", buf, want)
+ }
+
+ *p = RawPacket{}
+
+ if err := p.ReadFrom(bytes.NewReader(buf), nil, DefaultMaxPacketLength); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.PacketType != PacketTypeStatus {
+ t.Errorf("RawPacket.UnmarshalBinary(): Type was %v, but expected %v", p.PacketType, PacketTypeStat)
+ }
+
+ if p.RequestID != uint32(id) {
+ t.Errorf("RawPacket.UnmarshalBinary(): RequestID was %d, but expected %d", p.RequestID, id)
+ }
+
+ want = []byte{
+ 0x00, 0x00, 0x00, 0x01,
+ 0x00, 0x00, 0x00, 3, 'e', 'o', 'f',
+ 0x00, 0x00, 0x00, 2, 'e', 'n',
+ }
+
+ if !bytes.Equal(p.Data.Bytes(), want) {
+ t.Fatalf("RawPacket.UnmarshalBinary(): Data was %X, but expected %X", p.Data, want)
+ }
+
+ var resp StatusPacket
+ resp.UnmarshalPacketBody(&p.Data)
+
+ if resp.StatusCode != StatusEOF {
+ t.Errorf("UnmarshalPacketBody(): StatusCode was %v, but expected %v", resp.StatusCode, StatusEOF)
+ }
+
+ if resp.ErrorMessage != errMsg {
+ t.Errorf("UnmarshalPacketBody(): ErrorMessage was %q, but expected %q", resp.ErrorMessage, errMsg)
+ }
+
+ if resp.LanguageTag != langTag {
+ t.Errorf("UnmarshalPacketBody(): LanguageTag was %q, but expected %q", resp.LanguageTag, langTag)
+ }
+}
+
+func TestRequestPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "foo"
+ )
+
+ p := &RequestPacket{
+ RequestID: id,
+ Request: &StatPacket{
+ Path: path,
+ },
+ }
+
+ buf, err := p.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 12,
+ 17,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 3, 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Errorf("RequestPacket.MarshalBinary() = %X, but wanted %X", buf, want)
+ }
+
+ *p = RequestPacket{}
+
+ if err := p.ReadFrom(bytes.NewReader(buf), nil, DefaultMaxPacketLength); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.RequestID != uint32(id) {
+ t.Errorf("RequestPacket.UnmarshalBinary(): RequestID was %d, but expected %d", p.RequestID, id)
+ }
+
+ req, ok := p.Request.(*StatPacket)
+ if !ok {
+ t.Fatalf("unexpected Request type was %T, but expected %T", p.Request, req)
+ }
+
+ if req.Path != path {
+ t.Errorf("RequestPacket.UnmarshalBinary(): Request.Path was %q, but expected %q", req.Path, path)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/path_packets.go b/internal/encoding/ssh/filexfer/path_packets.go
new file mode 100644
index 0000000..e6f692d
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/path_packets.go
@@ -0,0 +1,368 @@
+package filexfer
+
+// LStatPacket defines the SSH_FXP_LSTAT packet.
+type LStatPacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *LStatPacket) Type() PacketType {
+ return PacketTypeLStat
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *LStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) // string(path)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeLStat, reqid)
+ buf.AppendString(p.Path)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *LStatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// SetstatPacket defines the SSH_FXP_SETSTAT packet.
+type SetstatPacket struct {
+ Path string
+ Attrs Attributes
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *SetstatPacket) Type() PacketType {
+ return PacketTypeSetstat
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *SetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) + p.Attrs.Len() // string(path) + ATTRS(attrs)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeSetstat, reqid)
+ buf.AppendString(p.Path)
+
+ p.Attrs.MarshalInto(buf)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *SetstatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return p.Attrs.UnmarshalFrom(buf)
+}
+
+// RemovePacket defines the SSH_FXP_REMOVE packet.
+type RemovePacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *RemovePacket) Type() PacketType {
+ return PacketTypeRemove
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *RemovePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) // string(path)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeRemove, reqid)
+ buf.AppendString(p.Path)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *RemovePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// MkdirPacket defines the SSH_FXP_MKDIR packet.
+type MkdirPacket struct {
+ Path string
+ Attrs Attributes
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *MkdirPacket) Type() PacketType {
+ return PacketTypeMkdir
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *MkdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) + p.Attrs.Len() // string(path) + ATTRS(attrs)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeMkdir, reqid)
+ buf.AppendString(p.Path)
+
+ p.Attrs.MarshalInto(buf)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *MkdirPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return p.Attrs.UnmarshalFrom(buf)
+}
+
+// RmdirPacket defines the SSH_FXP_RMDIR packet.
+type RmdirPacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *RmdirPacket) Type() PacketType {
+ return PacketTypeRmdir
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *RmdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) // string(path)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeRmdir, reqid)
+ buf.AppendString(p.Path)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *RmdirPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// RealPathPacket defines the SSH_FXP_REALPATH packet.
+type RealPathPacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *RealPathPacket) Type() PacketType {
+ return PacketTypeRealPath
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *RealPathPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) // string(path)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeRealPath, reqid)
+ buf.AppendString(p.Path)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *RealPathPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// StatPacket defines the SSH_FXP_STAT packet.
+type StatPacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *StatPacket) Type() PacketType {
+ return PacketTypeStat
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *StatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) // string(path)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeStat, reqid)
+ buf.AppendString(p.Path)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *StatPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// RenamePacket defines the SSH_FXP_RENAME packet.
+type RenamePacket struct {
+ OldPath string
+ NewPath string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *RenamePacket) Type() PacketType {
+ return PacketTypeRename
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *RenamePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ // string(oldpath) + string(newpath)
+ size := 4 + len(p.OldPath) + 4 + len(p.NewPath)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeRename, reqid)
+ buf.AppendString(p.OldPath)
+ buf.AppendString(p.NewPath)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *RenamePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.OldPath, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if p.NewPath, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// ReadLinkPacket defines the SSH_FXP_READLINK packet.
+type ReadLinkPacket struct {
+ Path string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ReadLinkPacket) Type() PacketType {
+ return PacketTypeReadLink
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *ReadLinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Path) // string(path)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeReadLink, reqid)
+ buf.AppendString(p.Path)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *ReadLinkPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Path, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// SymlinkPacket defines the SSH_FXP_SYMLINK packet.
+//
+// The order of the arguments to the SSH_FXP_SYMLINK method was inadvertently reversed.
+// Unfortunately, the reversal was not noticed until the server was widely deployed.
+// Covered in Section 3.1 of https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
+type SymlinkPacket struct {
+ LinkPath string
+ TargetPath string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *SymlinkPacket) Type() PacketType {
+ return PacketTypeSymlink
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *SymlinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ // string(targetpath) + string(linkpath)
+ size := 4 + len(p.TargetPath) + 4 + len(p.LinkPath)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeSymlink, reqid)
+
+ // Arguments were inadvertently reversed.
+ buf.AppendString(p.TargetPath)
+ buf.AppendString(p.LinkPath)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *SymlinkPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ // Arguments were inadvertently reversed.
+ if p.TargetPath, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if p.LinkPath, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/encoding/ssh/filexfer/path_packets_test.go b/internal/encoding/ssh/filexfer/path_packets_test.go
new file mode 100644
index 0000000..4cff582
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/path_packets_test.go
@@ -0,0 +1,450 @@
+package filexfer
+
+import (
+ "bytes"
+ "testing"
+)
+
+var _ Packet = &LStatPacket{}
+
+func TestLStatPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ p := &LStatPacket{
+ Path: path,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 13,
+ 7,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = LStatPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+}
+
+var _ Packet = &SetstatPacket{}
+
+func TestSetstatPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ perms = 0x87654321
+ )
+
+ p := &SetstatPacket{
+ Path: "/foo",
+ Attrs: Attributes{
+ Flags: AttrPermissions,
+ Permissions: perms,
+ },
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 21,
+ 9,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 0x04,
+ 0x87, 0x65, 0x43, 0x21,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = SetstatPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+
+ if p.Attrs.Flags != AttrPermissions {
+ t.Errorf("UnmarshalPacketBody(): Attrs.Flags was %#x, but expected %#x", p.Attrs.Flags, AttrPermissions)
+ }
+
+ if p.Attrs.Permissions != perms {
+ t.Errorf("UnmarshalPacketBody(): Attrs.Permissions was %#v, but expected %#v", p.Attrs.Permissions, perms)
+ }
+}
+
+var _ Packet = &RemovePacket{}
+
+func TestRemovePacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ p := &RemovePacket{
+ Path: path,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 13,
+ 13,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = RemovePacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+}
+
+var _ Packet = &MkdirPacket{}
+
+func TestMkdirPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ perms = 0x87654321
+ )
+
+ p := &MkdirPacket{
+ Path: "/foo",
+ Attrs: Attributes{
+ Flags: AttrPermissions,
+ Permissions: perms,
+ },
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 21,
+ 14,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 0x04,
+ 0x87, 0x65, 0x43, 0x21,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = MkdirPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+
+ if p.Attrs.Flags != AttrPermissions {
+ t.Errorf("UnmarshalPacketBody(): Attrs.Flags was %#x, but expected %#x", p.Attrs.Flags, AttrPermissions)
+ }
+
+ if p.Attrs.Permissions != perms {
+ t.Errorf("UnmarshalPacketBody(): Attrs.Permissions was %#v, but expected %#v", p.Attrs.Permissions, perms)
+ }
+}
+
+var _ Packet = &RmdirPacket{}
+
+func TestRmdirPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ p := &RmdirPacket{
+ Path: path,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 13,
+ 15,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = RmdirPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+}
+
+var _ Packet = &RealPathPacket{}
+
+func TestRealPathPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ p := &RealPathPacket{
+ Path: path,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 13,
+ 16,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = RealPathPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+}
+
+var _ Packet = &StatPacket{}
+
+func TestStatPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ p := &StatPacket{
+ Path: path,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 13,
+ 17,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = StatPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+}
+
+var _ Packet = &RenamePacket{}
+
+func TestRenamePacket(t *testing.T) {
+ const (
+ id = 42
+ oldpath = "/foo"
+ newpath = "/bar"
+ )
+
+ p := &RenamePacket{
+ OldPath: oldpath,
+ NewPath: newpath,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 21,
+ 18,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = RenamePacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.OldPath != oldpath {
+ t.Errorf("UnmarshalPacketBody(): OldPath was %q, but expected %q", p.OldPath, oldpath)
+ }
+
+ if p.NewPath != newpath {
+ t.Errorf("UnmarshalPacketBody(): NewPath was %q, but expected %q", p.NewPath, newpath)
+ }
+}
+
+var _ Packet = &ReadLinkPacket{}
+
+func TestReadLinkPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ p := &ReadLinkPacket{
+ Path: path,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 13,
+ 19,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = ReadLinkPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path)
+ }
+}
+
+var _ Packet = &SymlinkPacket{}
+
+func TestSymlinkPacket(t *testing.T) {
+ const (
+ id = 42
+ linkpath = "/foo"
+ targetpath = "/bar"
+ )
+
+ p := &SymlinkPacket{
+ LinkPath: linkpath,
+ TargetPath: targetpath,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 21,
+ 20,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r', // Arguments were inadvertently reversed.
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = SymlinkPacket{}
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.LinkPath != linkpath {
+ t.Errorf("UnmarshalPacketBody(): LinkPath was %q, but expected %q", p.LinkPath, linkpath)
+ }
+
+ if p.TargetPath != targetpath {
+ t.Errorf("UnmarshalPacketBody(): TargetPath was %q, but expected %q", p.TargetPath, targetpath)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/permissions.go b/internal/encoding/ssh/filexfer/permissions.go
new file mode 100644
index 0000000..2fe63d5
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/permissions.go
@@ -0,0 +1,114 @@
+package filexfer
+
+// FileMode represents a file’s mode and permission bits.
+// The bits are defined according to POSIX standards,
+// and may not apply to the OS being built for.
+type FileMode uint32
+
+// Permission flags, defined here to avoid potential inconsistencies in individual OS implementations.
+const (
+ ModePerm FileMode = 0o0777 // S_IRWXU | S_IRWXG | S_IRWXO
+ ModeUserRead FileMode = 0o0400 // S_IRUSR
+ ModeUserWrite FileMode = 0o0200 // S_IWUSR
+ ModeUserExec FileMode = 0o0100 // S_IXUSR
+ ModeGroupRead FileMode = 0o0040 // S_IRGRP
+ ModeGroupWrite FileMode = 0o0020 // S_IWGRP
+ ModeGroupExec FileMode = 0o0010 // S_IXGRP
+ ModeOtherRead FileMode = 0o0004 // S_IROTH
+ ModeOtherWrite FileMode = 0o0002 // S_IWOTH
+ ModeOtherExec FileMode = 0o0001 // S_IXOTH
+
+ ModeSetUID FileMode = 0o4000 // S_ISUID
+ ModeSetGID FileMode = 0o2000 // S_ISGID
+ ModeSticky FileMode = 0o1000 // S_ISVTX
+
+ ModeType FileMode = 0xF000 // S_IFMT
+ ModeNamedPipe FileMode = 0x1000 // S_IFIFO
+ ModeCharDevice FileMode = 0x2000 // S_IFCHR
+ ModeDir FileMode = 0x4000 // S_IFDIR
+ ModeDevice FileMode = 0x6000 // S_IFBLK
+ ModeRegular FileMode = 0x8000 // S_IFREG
+ ModeSymlink FileMode = 0xA000 // S_IFLNK
+ ModeSocket FileMode = 0xC000 // S_IFSOCK
+)
+
+// IsDir reports whether m describes a directory.
+// That is, it tests for m.Type() == ModeDir.
+func (m FileMode) IsDir() bool {
+ return (m & ModeType) == ModeDir
+}
+
+// IsRegular reports whether m describes a regular file.
+// That is, it tests for m.Type() == ModeRegular
+func (m FileMode) IsRegular() bool {
+ return (m & ModeType) == ModeRegular
+}
+
+// Perm returns the POSIX permission bits in m (m & ModePerm).
+func (m FileMode) Perm() FileMode {
+ return (m & ModePerm)
+}
+
+// Type returns the type bits in m (m & ModeType).
+func (m FileMode) Type() FileMode {
+ return (m & ModeType)
+}
+
+// String returns a `-rwxrwxrwx` style string representing the `ls -l` POSIX permissions string.
+func (m FileMode) String() string {
+ var buf [10]byte
+
+ switch m.Type() {
+ case ModeRegular:
+ buf[0] = '-'
+ case ModeDir:
+ buf[0] = 'd'
+ case ModeSymlink:
+ buf[0] = 'l'
+ case ModeDevice:
+ buf[0] = 'b'
+ case ModeCharDevice:
+ buf[0] = 'c'
+ case ModeNamedPipe:
+ buf[0] = 'p'
+ case ModeSocket:
+ buf[0] = 's'
+ default:
+ buf[0] = '?'
+ }
+
+ const rwx = "rwxrwxrwx"
+ for i, c := range rwx {
+ if m&(1<<uint(9-1-i)) != 0 {
+ buf[i+1] = byte(c)
+ } else {
+ buf[i+1] = '-'
+ }
+ }
+
+ if m&ModeSetUID != 0 {
+ if buf[3] == 'x' {
+ buf[3] = 's'
+ } else {
+ buf[3] = 'S'
+ }
+ }
+
+ if m&ModeSetGID != 0 {
+ if buf[6] == 'x' {
+ buf[6] = 's'
+ } else {
+ buf[6] = 'S'
+ }
+ }
+
+ if m&ModeSticky != 0 {
+ if buf[9] == 'x' {
+ buf[9] = 't'
+ } else {
+ buf[9] = 'T'
+ }
+ }
+
+ return string(buf[:])
+}
diff --git a/internal/encoding/ssh/filexfer/response_packets.go b/internal/encoding/ssh/filexfer/response_packets.go
new file mode 100644
index 0000000..7a9b3ea
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/response_packets.go
@@ -0,0 +1,243 @@
+package filexfer
+
+import (
+ "fmt"
+)
+
+// StatusPacket defines the SSH_FXP_STATUS packet.
+//
+// Specified in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-7
+type StatusPacket struct {
+ StatusCode Status
+ ErrorMessage string
+ LanguageTag string
+}
+
+// Error makes StatusPacket an error type.
+func (p *StatusPacket) Error() string {
+ if p.ErrorMessage == "" {
+ return "sftp: " + p.StatusCode.String()
+ }
+
+ return fmt.Sprintf("sftp: %q (%s)", p.ErrorMessage, p.StatusCode)
+}
+
+// Is returns true if target is a StatusPacket with the same StatusCode,
+// or target is a Status code which is the same as SatusCode.
+func (p *StatusPacket) Is(target error) bool {
+ if target, ok := target.(*StatusPacket); ok {
+ return p.StatusCode == target.StatusCode
+ }
+
+ return p.StatusCode == target
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *StatusPacket) Type() PacketType {
+ return PacketTypeStatus
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *StatusPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ // uint32(error/status code) + string(error message) + string(language tag)
+ size := 4 + 4 + len(p.ErrorMessage) + 4 + len(p.LanguageTag)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeStatus, reqid)
+ buf.AppendUint32(uint32(p.StatusCode))
+ buf.AppendString(p.ErrorMessage)
+ buf.AppendString(p.LanguageTag)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *StatusPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ statusCode, err := buf.ConsumeUint32()
+ if err != nil {
+ return err
+ }
+ p.StatusCode = Status(statusCode)
+
+ if p.ErrorMessage, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ if p.LanguageTag, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// HandlePacket defines the SSH_FXP_HANDLE packet.
+type HandlePacket struct {
+ Handle string
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *HandlePacket) Type() PacketType {
+ return PacketTypeHandle
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *HandlePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 + len(p.Handle) // string(handle)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeHandle, reqid)
+ buf.AppendString(p.Handle)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *HandlePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ if p.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// DataPacket defines the SSH_FXP_DATA packet.
+type DataPacket struct {
+ Data []byte
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *DataPacket) Type() PacketType {
+ return PacketTypeData
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *DataPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 // uint32(len(data)); data content in payload
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeData, reqid)
+ buf.AppendUint32(uint32(len(p.Data)))
+
+ return buf.Packet(p.Data)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+//
+// If p.Data is already populated, and of sufficient length to hold the data,
+// then this will copy the data into that byte slice.
+//
+// If p.Data has a length insufficient to hold the data,
+// then this will make a new slice of sufficient length, and copy the data into that.
+//
+// This means this _does not_ alias any of the data buffer that is passed in.
+func (p *DataPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ data, err := buf.ConsumeByteSlice()
+ if err != nil {
+ return err
+ }
+
+ if len(p.Data) < len(data) {
+ p.Data = make([]byte, len(data))
+ }
+
+ n := copy(p.Data, data)
+ p.Data = p.Data[:n]
+ return nil
+}
+
+// NamePacket defines the SSH_FXP_NAME packet.
+type NamePacket struct {
+ Entries []*NameEntry
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *NamePacket) Type() PacketType {
+ return PacketTypeName
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *NamePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := 4 // uint32(len(entries))
+
+ for _, e := range p.Entries {
+ size += e.Len()
+ }
+
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeName, reqid)
+ buf.AppendUint32(uint32(len(p.Entries)))
+
+ for _, e := range p.Entries {
+ e.MarshalInto(buf)
+ }
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *NamePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ count, err := buf.ConsumeUint32()
+ if err != nil {
+ return err
+ }
+
+ p.Entries = make([]*NameEntry, 0, count)
+
+ for i := uint32(0); i < count; i++ {
+ var e NameEntry
+ if err := e.UnmarshalFrom(buf); err != nil {
+ return err
+ }
+
+ p.Entries = append(p.Entries, &e)
+ }
+
+ return nil
+}
+
+// AttrsPacket defines the SSH_FXP_ATTRS packet.
+type AttrsPacket struct {
+ Attrs Attributes
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *AttrsPacket) Type() PacketType {
+ return PacketTypeAttrs
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+func (p *AttrsPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ buf := NewBuffer(b)
+ if buf.Cap() < 9 {
+ size := p.Attrs.Len() // ATTRS(attrs)
+ buf = NewMarshalBuffer(size)
+ }
+
+ buf.StartPacket(PacketTypeAttrs, reqid)
+ p.Attrs.MarshalInto(buf)
+
+ return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+func (p *AttrsPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+ return p.Attrs.UnmarshalFrom(buf)
+}
diff --git a/internal/encoding/ssh/filexfer/response_packets_test.go b/internal/encoding/ssh/filexfer/response_packets_test.go
new file mode 100644
index 0000000..9468665
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/response_packets_test.go
@@ -0,0 +1,296 @@
+package filexfer
+
+import (
+ "bytes"
+ "errors"
+ "testing"
+)
+
+func TestStatusPacketIs(t *testing.T) {
+ status := &StatusPacket{
+ StatusCode: StatusFailure,
+ ErrorMessage: "error message",
+ LanguageTag: "language tag",
+ }
+
+ if !errors.Is(status, StatusFailure) {
+ t.Error("errors.Is(StatusFailure, StatusFailure) != true")
+ }
+ if !errors.Is(status, &StatusPacket{StatusCode: StatusFailure}) {
+ t.Error("errors.Is(StatusFailure, StatusPacket{StatusFailure}) != true")
+ }
+ if errors.Is(status, StatusOK) {
+ t.Error("errors.Is(StatusFailure, StatusFailure) == true")
+ }
+ if errors.Is(status, &StatusPacket{StatusCode: StatusOK}) {
+ t.Error("errors.Is(StatusFailure, StatusPacket{StatusFailure}) == true")
+ }
+}
+
+var _ Packet = &StatusPacket{}
+
+func TestStatusPacket(t *testing.T) {
+ const (
+ id = 42
+ statusCode = StatusBadMessage
+ errorMessage = "foo"
+ languageTag = "x-example"
+ )
+
+ p := &StatusPacket{
+ StatusCode: statusCode,
+ ErrorMessage: errorMessage,
+ LanguageTag: languageTag,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 29,
+ 101,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 5,
+ 0x00, 0x00, 0x00, 3, 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 9, 'x', '-', 'e', 'x', 'a', 'm', 'p', 'l', 'e',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = StatusPacket{}
+
+ // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.StatusCode != statusCode {
+ t.Errorf("UnmarshalBinary(): StatusCode was %v, but expected %v", p.StatusCode, statusCode)
+ }
+
+ if p.ErrorMessage != errorMessage {
+ t.Errorf("UnmarshalBinary(): ErrorMessage was %q, but expected %q", p.ErrorMessage, errorMessage)
+ }
+
+ if p.LanguageTag != languageTag {
+ t.Errorf("UnmarshalBinary(): LanguageTag was %q, but expected %q", p.LanguageTag, languageTag)
+ }
+}
+
+var _ Packet = &HandlePacket{}
+
+func TestHandlePacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ )
+
+ p := &HandlePacket{
+ Handle: "somehandle",
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 19,
+ 102,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = HandlePacket{}
+
+ // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Handle != handle {
+ t.Errorf("UnmarshalBinary(): Handle was %q, but expected %q", p.Handle, handle)
+ }
+}
+
+var _ Packet = &DataPacket{}
+
+func TestDataPacket(t *testing.T) {
+ const (
+ id = 42
+ )
+
+ var payload = []byte(`foobar`)
+
+ p := &DataPacket{
+ Data: payload,
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 15,
+ 103,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 6, 'f', 'o', 'o', 'b', 'a', 'r',
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = DataPacket{}
+
+ // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if !bytes.Equal(p.Data, payload) {
+ t.Errorf("UnmarshalBinary(): Data was %X, but expected %X", p.Data, payload)
+ }
+}
+
+var _ Packet = &NamePacket{}
+
+func TestNamePacket(t *testing.T) {
+ const (
+ id = 42
+ filename = "foo"
+ longname = "bar"
+ perms = 0x87654300
+ )
+
+ p := &NamePacket{
+ Entries: []*NameEntry{
+ &NameEntry{
+ Filename: filename + "1",
+ Longname: longname + "1",
+ Attrs: Attributes{
+ Flags: AttrPermissions | (1 << 8),
+ Permissions: perms | 1,
+ },
+ },
+ &NameEntry{
+ Filename: filename + "2",
+ Longname: longname + "2",
+ Attrs: Attributes{
+ Flags: AttrPermissions | (2 << 8),
+ Permissions: perms | 2,
+ },
+ },
+ },
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 57,
+ 104,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 0x02,
+ 0x00, 0x00, 0x00, 4, 'f', 'o', 'o', '1',
+ 0x00, 0x00, 0x00, 4, 'b', 'a', 'r', '1',
+ 0x00, 0x00, 0x01, 0x04,
+ 0x87, 0x65, 0x43, 0x01,
+ 0x00, 0x00, 0x00, 4, 'f', 'o', 'o', '2',
+ 0x00, 0x00, 0x00, 4, 'b', 'a', 'r', '2',
+ 0x00, 0x00, 0x02, 0x04,
+ 0x87, 0x65, 0x43, 0x02,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = NamePacket{}
+
+ // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if count := len(p.Entries); count != 2 {
+ t.Fatalf("UnmarshalBinary(): len(NameEntries) was %d, but expected %d", count, 2)
+ }
+
+ for i, e := range p.Entries {
+ if got, want := e.Filename, filename+string('1'+rune(i)); got != want {
+ t.Errorf("UnmarshalBinary(): Entries[%d].Filename was %q, but expected %q", i, got, want)
+ }
+
+ if got, want := e.Longname, longname+string('1'+rune(i)); got != want {
+ t.Errorf("UnmarshalBinary(): Entries[%d].Longname was %q, but expected %q", i, got, want)
+ }
+
+ if got, want := e.Attrs.Flags, AttrPermissions|((i+1)<<8); got != uint32(want) {
+ t.Errorf("UnmarshalBinary(): Entries[%d].Attrs.Flags was %#x, but expected %#x", i, got, want)
+ }
+
+ if got, want := e.Attrs.Permissions, FileMode(perms|(i+1)); got != want {
+ t.Errorf("UnmarshalBinary(): Entries[%d].Attrs.Permissions was %#v, but expected %#v", i, got, want)
+ }
+ }
+}
+
+var _ Packet = &AttrsPacket{}
+
+func TestAttrsPacket(t *testing.T) {
+ const (
+ id = 42
+ perms = 0x87654321
+ )
+
+ p := &AttrsPacket{
+ Attrs: Attributes{
+ Flags: AttrPermissions,
+ Permissions: perms,
+ },
+ }
+
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 13,
+ 105,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 0x04,
+ 0x87, 0x65, 0x43, 0x21,
+ }
+
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
+ }
+
+ *p = AttrsPacket{}
+
+ // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.Attrs.Flags != AttrPermissions {
+ t.Errorf("UnmarshalBinary(): Attrs.Flags was %#x, but expected %#x", p.Attrs.Flags, AttrPermissions)
+ }
+
+ if p.Attrs.Permissions != perms {
+ t.Errorf("UnmarshalBinary(): Attrs.Permissions was %#v, but expected %#v", p.Attrs.Permissions, perms)
+ }
+}
diff --git a/main.go b/main.go
index cd68a46..accc002 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,16 @@
package main
import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "git.deuxfleurs.fr/Deuxfleurs/bagage/sftp"
+ "git.deuxfleurs.fr/Deuxfleurs/bagage/s3"
+ "github.com/minio/minio-go/v7/pkg/credentials"
+ "github.com/minio/minio-go/v7"
+ "golang.org/x/crypto/ssh"
"log"
"net/http"
)
@@ -11,6 +21,149 @@ func main() {
log.Println(config)
+ done := make(chan error)
+
+ go httpServer(config, done)
+ go sshServer(config, done)
+
+ err := <- done
+ if err != nil {
+ log.Fatalf("A component failed: %v", err)
+ }
+}
+
+type s3creds struct {
+ accessKey string
+ secretKey string
+}
+var keychain map[string]s3creds
+
+func sshServer(dconfig* Config, done chan error) {
+ keychain = make(map[string]s3creds)
+
+ config := &ssh.ServerConfig{
+ PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
+ log.Printf("Login: %s\n", c.User())
+ access_key, secret_key, err := LdapGetS3(dconfig, c.User(), string(pass))
+ if err == nil {
+ keychain[c.User()] = s3creds{ access_key, secret_key }
+ }
+ return nil, err
+ },
+ }
+
+ privateBytes, err := ioutil.ReadFile(dconfig.SSHKey)
+ if err != nil {
+ log.Fatal("Failed to load private key", err)
+ }
+
+ private, err := ssh.ParsePrivateKey(privateBytes)
+ if err != nil {
+ log.Fatal("Failed to parse private key", err)
+ }
+
+ config.AddHostKey(private)
+
+ // Once a ServerConfig has been configured, connections can be
+ // accepted.
+ listener, err := net.Listen("tcp", "0.0.0.0:2222")
+ if err != nil {
+ log.Fatal("failed to listen for connection", err)
+ }
+ log.Printf("Listening on %v\n", listener.Addr())
+
+ for {
+ nConn, err := listener.Accept()
+ if err != nil {
+ log.Printf("failed to accept incoming connection: ", err)
+ continue
+ }
+ go handleSSHConn(nConn, dconfig, config)
+ }
+}
+
+func handleSSHConn(nConn net.Conn, dconfig* Config, config *ssh.ServerConfig) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ defer nConn.Close()
+
+ // Before use, a handshake must be performed on the incoming
+ // net.Conn.
+ serverConn, chans, reqs, err := ssh.NewServerConn(nConn, config)
+ if err != nil {
+ log.Printf("failed to handshake: ", err)
+ }
+ defer serverConn.Conn.Close()
+ user := serverConn.Conn.User()
+ log.Printf("SSH connection established for %v\n", user)
+
+ // The incoming Request channel must be serviced.
+ go ssh.DiscardRequests(reqs)
+
+ // Service the incoming Channel channel.
+ for newChannel := range chans {
+ // Channels have a type, depending on the application level
+ // protocol intended. In the case of an SFTP session, this is "subsystem"
+ // with a payload string of "<length=4>sftp"
+ log.Printf("Incoming channel: %s\n", newChannel.ChannelType())
+ if newChannel.ChannelType() != "session" {
+ newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
+ log.Printf("Unknown channel type: %s\n", newChannel.ChannelType())
+ continue
+ }
+
+ channel, requests, err := newChannel.Accept()
+ if err != nil {
+ log.Print("could not accept channel.", err)
+ }
+ log.Printf("Channel accepted\n")
+
+ // Sessions have out-of-band requests such as "shell",
+ // "pty-req" and "env". Here we handle only the
+ // "subsystem" request.
+ go func(in <-chan *ssh.Request) {
+ for req := range in {
+ log.Printf("Request: %v\n", req.Type)
+ ok := false
+ switch req.Type {
+ case "subsystem":
+ log.Printf("Subsystem: %s\n", req.Payload[4:])
+ if string(req.Payload[4:]) == "sftp" {
+ ok = true
+ }
+ }
+ log.Printf(" - accepted: %v\n", ok)
+ req.Reply(ok, nil)
+ }
+ }(requests)
+
+ creds := keychain[user]
+ mc, err := minio.New(dconfig.Endpoint, &minio.Options{
+ Creds: credentials.NewStaticV4(creds.accessKey, creds.secretKey, ""),
+ Secure: dconfig.UseSSL,
+ })
+ if err != nil {
+ return
+ }
+
+ fs := s3.NewS3FS(mc)
+ server, err := sftp.NewServer(ctx, channel, &fs)
+
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ if err := server.Serve(); err == io.EOF {
+ server.Close()
+ log.Print("sftp client exited session.")
+ } else if err != nil {
+ log.Print("sftp server completed with error:", err)
+ }
+ }
+}
+
+func httpServer(config* Config, done chan error) {
// Assemble components to handle WebDAV requests
http.Handle(config.DavPath+"/",
BasicAuthExtract{
@@ -30,6 +183,8 @@ func main() {
})
if err := http.ListenAndServe(config.HttpListen, nil); err != nil {
- log.Fatalf("Error with WebDAV server: %v", err)
+ done <- fmt.Errorf("Error with WebDAV server: %v", err)
+ } else {
+ done <- nil
}
}
diff --git a/s3_file.go b/s3/file.go
index bbcfc8a..aa4f227 100644
--- a/s3_file.go
+++ b/s3/file.go
@@ -1,4 +1,4 @@
-package main
+package s3
import (
"context"
@@ -11,7 +11,6 @@ import (
"path"
"github.com/minio/minio-go/v7"
- "golang.org/x/net/webdav"
)
type S3File struct {
@@ -20,14 +19,16 @@ type S3File struct {
objw *io.PipeWriter
donew chan error
pos int64
- path S3Path
+ entries []fs.FileInfo
+ Path S3Path
}
-func NewS3File(s *S3FS, path string) (webdav.File, error) {
+func NewS3File(s *S3FS, path string) (*S3File, error) {
f := new(S3File)
f.fs = s
f.pos = 0
- f.path = NewS3Path(path)
+ f.entries = nil
+ f.Path = NewS3Path(path)
return f, nil
}
@@ -62,7 +63,7 @@ func (f *S3File) Close() error {
func (f *S3File) loadObject() error {
if f.obj == nil {
- obj, err := f.fs.mc.GetObject(f.fs.ctx, f.path.bucket, f.path.key, minio.GetObjectOptions{})
+ obj, err := f.fs.mc.GetObject(f.fs.ctx, f.Path.Bucket, f.Path.Key, minio.GetObjectOptions{})
if err != nil {
return err
}
@@ -82,6 +83,19 @@ func (f *S3File) Read(p []byte) (n int, err error) {
return f.obj.Read(p)
}
+func (f *S3File) ReadAt(p []byte, off int64) (n int, err error) {
+ if err := f.loadObject(); err != nil {
+ return 0, err
+ }
+
+ return f.obj.ReadAt(p, off)
+}
+
+func (f *S3File) WriteAt(p []byte, off int64) (n int, err error) {
+ return 0, errors.New("not implemented")
+
+}
+
func (f *S3File) Write(p []byte) (n int, err error) {
/*if f.path.class != OBJECT {
return 0, os.ErrInvalid
@@ -96,7 +110,7 @@ func (f *S3File) Write(p []byte) (n int, err error) {
f.donew = make(chan error, 1)
f.objw = w
- contentType := mime.TypeByExtension(path.Ext(f.path.key))
+ contentType := mime.TypeByExtension(path.Ext(f.Path.Key))
go func() {
/* @FIXME
PutObject has a strange behaviour when used with unknown size, it supposes the final size will be 5TiB.
@@ -107,7 +121,7 @@ func (f *S3File) Write(p []byte) (n int, err error) {
Because Multipart uploads seems to be limited to 10 000 parts, it might be possible that we are limited to 50 GiB files, which is still good enough.
Ref: https://github.com/minio/minio-go/blob/62beca8cd87e9960d88793320220ad2c159bb5e5/api-put-object-common.go#L110-L112
*/
- _, err := f.fs.mc.PutObject(context.Background(), f.path.bucket, f.path.key, r, -1, minio.PutObjectOptions{ContentType: contentType, PartSize: 5*1024*1024})
+ _, err := f.fs.mc.PutObject(context.Background(), f.Path.Bucket, f.Path.Key, r, -1, minio.PutObjectOptions{ContentType: contentType, PartSize: 5*1024*1024})
f.donew <- err
}()
}
@@ -133,43 +147,59 @@ If n > 0, ReadDir returns at most n DirEntry records. In this case, if ReadDir r
If n <= 0, ReadDir returns all the DirEntry records remaining in the directory. When it succeeds, it returns a nil error (not io.EOF).
*/
func (f *S3File) Readdir(count int) ([]fs.FileInfo, error) {
- if count > 0 {
- return nil, errors.New("returning a limited number of directory entry is not supported in readdir")
- }
-
- if f.path.class == ROOT {
+ if f.Path.Class == ROOT {
return f.readDirRoot(count)
} else {
return f.readDirChild(count)
}
}
-func (f *S3File) readDirRoot(count int) ([]fs.FileInfo, error) {
- buckets, err := f.fs.mc.ListBuckets(f.fs.ctx)
- if err != nil {
- return nil, err
- }
+func min(a, b int64) int64 {
+ if a < b {
+ return a
+ }
+ return b
+}
- entries := make([]fs.FileInfo, 0, len(buckets))
- for _, bucket := range buckets {
- //log.Println("Stat from GarageFile.readDirRoot()", "/"+bucket.Name)
- nf, err := NewS3Stat(f.fs, "/"+bucket.Name)
+func (f *S3File) readDirRoot(count int) ([]fs.FileInfo, error) {
+ var err error
+ if f.entries == nil {
+ buckets, err := f.fs.mc.ListBuckets(f.fs.ctx)
if err != nil {
return nil, err
}
- entries = append(entries, nf)
+
+ f.entries = make([]fs.FileInfo, 0, len(buckets))
+ for _, bucket := range buckets {
+ //log.Println("Stat from GarageFile.readDirRoot()", "/"+bucket.Name)
+ nf, err := NewS3Stat(f.fs, "/"+bucket.Name)
+ if err != nil {
+ return nil, err
+ }
+ f.entries = append(f.entries, nf)
+ }
+ }
+ beg := f.pos
+ end := int64(len(f.entries))
+ if count > 0 {
+ end = min(beg + int64(count), end)
}
+ f.pos = end
- return entries, nil
+ if end - beg == 0 {
+ err = io.EOF
+ }
+
+ return f.entries[beg:end], err
}
func (f *S3File) readDirChild(count int) ([]fs.FileInfo, error) {
- prefix := f.path.key
+ prefix := f.Path.Key
if len(prefix) > 0 && prefix[len(prefix)-1:] != "/" {
prefix = prefix + "/"
}
- objs_info := f.fs.mc.ListObjects(f.fs.ctx, f.path.bucket, minio.ListObjectsOptions{
+ objs_info := f.fs.mc.ListObjects(f.fs.ctx, f.Path.Bucket, minio.ListObjectsOptions{
Prefix: prefix,
Recursive: false,
})
@@ -180,7 +210,7 @@ func (f *S3File) readDirChild(count int) ([]fs.FileInfo, error) {
return nil, object.Err
}
//log.Println("Stat from GarageFile.readDirChild()", path.Join("/", f.path.bucket, object.Key))
- nf, err := NewS3StatFromObjectInfo(f.fs, f.path.bucket, object)
+ nf, err := NewS3StatFromObjectInfo(f.fs, f.Path.Bucket, object)
if err != nil {
return nil, err
}
@@ -191,5 +221,5 @@ func (f *S3File) readDirChild(count int) ([]fs.FileInfo, error) {
}
func (f *S3File) Stat() (fs.FileInfo, error) {
- return NewS3Stat(f.fs, f.path.path)
+ return NewS3Stat(f.fs, f.Path.Path)
}
diff --git a/s3_fs.go b/s3/fs.go
index e93edb3..c5ae6a0 100644
--- a/s3_fs.go
+++ b/s3/fs.go
@@ -1,4 +1,4 @@
-package main
+package s3
import (
"context"
@@ -37,9 +37,9 @@ func (s S3FS) Mkdir(ctx context.Context, name string, perm os.FileMode) error {
p := NewS3Path(name)
- if p.class == ROOT {
+ if p.Class == ROOT {
return errors.New("Unable to create another root folder")
- } else if p.class == BUCKET {
+ } else if p.Class == BUCKET {
log.Println("Creating bucket is not implemented yet")
return nil
}
@@ -54,7 +54,7 @@ func (s S3FS) Mkdir(ctx context.Context, name string, perm os.FileMode) error {
return nil
}
-func (s S3FS) OpenFile(ctx context.Context, name string, flag int, perm os.FileMode) (webdav.File, error) {
+func (s S3FS) OpenFile2(ctx context.Context, name string, flag int, perm os.FileMode) (*S3File, error) {
s.ctx = ctx
// If the file does not exist when opening it, we create a stub
@@ -62,8 +62,8 @@ func (s S3FS) OpenFile(ctx context.Context, name string, flag int, perm os.FileM
st := new(S3Stat)
st.fs = &s
st.path = NewS3Path(name)
- st.path.class = OBJECT
- st.obj.Key = st.path.key
+ st.path.Class = OBJECT
+ st.obj.Key = st.path.Key
st.obj.LastModified = time.Now()
s.cache[name] = st
}
@@ -71,20 +71,24 @@ func (s S3FS) OpenFile(ctx context.Context, name string, flag int, perm os.FileM
return NewS3File(&s, name)
}
+func (s S3FS) OpenFile(ctx context.Context, name string, flag int, perm os.FileMode) (webdav.File, error) {
+ return s.OpenFile2(ctx, name, flag, perm)
+}
+
func (s S3FS) RemoveAll(ctx context.Context, name string) error {
//@FIXME nautilus deletes files one by one, at the end, it does not find its folder as it is "already deleted"
s.ctx = ctx
p := NewS3Path(name)
- if p.class == ROOT {
+ if p.Class == ROOT {
return errors.New("Unable to create another root folder")
- } else if p.class == BUCKET {
+ } else if p.Class == BUCKET {
log.Println("Deleting bucket is not implemented yet")
return nil
}
- objCh := s.mc.ListObjects(s.ctx, p.bucket, minio.ListObjectsOptions{Prefix: p.key, Recursive: true})
- rmCh := s.mc.RemoveObjects(s.ctx, p.bucket, objCh, minio.RemoveObjectsOptions{})
+ objCh := s.mc.ListObjects(s.ctx, p.Bucket, minio.ListObjectsOptions{Prefix: p.Key, Recursive: true})
+ rmCh := s.mc.RemoveObjects(s.ctx, p.Bucket, objCh, minio.RemoveObjectsOptions{})
for rErr := range rmCh {
return rErr.Err
@@ -98,9 +102,9 @@ func (s S3FS) Rename(ctx context.Context, oldName, newName string) error {
po := NewS3Path(oldName)
pn := NewS3Path(newName)
- if po.class == ROOT || pn.class == ROOT {
+ if po.Class == ROOT || pn.Class == ROOT {
return errors.New("Unable to rename root folder")
- } else if po.class == BUCKET || pn.class == BUCKET {
+ } else if po.Class == BUCKET || pn.Class == BUCKET {
log.Println("Moving a bucket is not implemented yet")
return nil
}
@@ -111,16 +115,16 @@ func (s S3FS) Rename(ctx context.Context, oldName, newName string) error {
}
//Gather all keys, copy the object, delete the original
- objCh := s.mc.ListObjects(s.ctx, po.bucket, minio.ListObjectsOptions{Prefix: po.key, Recursive: true})
+ objCh := s.mc.ListObjects(s.ctx, po.Bucket, minio.ListObjectsOptions{Prefix: po.Key, Recursive: true})
for obj := range objCh {
src := minio.CopySrcOptions{
- Bucket: po.bucket,
+ Bucket: po.Bucket,
Object: obj.Key,
}
dst := minio.CopyDestOptions{
- Bucket: pn.bucket,
- Object: path.Join(pn.key, obj.Key[len(po.key):]),
+ Bucket: pn.Bucket,
+ Object: path.Join(pn.Key, obj.Key[len(po.Key):]),
}
_, err := s.mc.CopyObject(s.ctx, dst, src)
@@ -128,7 +132,7 @@ func (s S3FS) Rename(ctx context.Context, oldName, newName string) error {
return err
}
- err = s.mc.RemoveObject(s.ctx, po.bucket, obj.Key, minio.RemoveObjectOptions{})
+ err = s.mc.RemoveObject(s.ctx, po.Bucket, obj.Key, minio.RemoveObjectOptions{})
var e minio.ErrorResponse
log.Println(errors.As(err, &e))
log.Println(e)
diff --git a/s3/path.go b/s3/path.go
new file mode 100644
index 0000000..b1a981b
--- /dev/null
+++ b/s3/path.go
@@ -0,0 +1,68 @@
+package s3
+
+import (
+ "path"
+ "strings"
+
+ "github.com/minio/minio-go/v7"
+)
+
+type S3Class int
+
+const (
+ ROOT S3Class = 1 << iota
+ BUCKET
+ COMMON_PREFIX
+ OBJECT
+ OPAQUE_KEY
+
+ KEY = COMMON_PREFIX | OBJECT | OPAQUE_KEY
+)
+
+type S3Path struct {
+ Path string
+ Class S3Class
+ Bucket string
+ Key string
+}
+
+func NewS3Path(p string) S3Path {
+ // Remove first dot, eq. relative directory == "/"
+ if len(p) > 0 && p[0] == '.' {
+ p = p[1:]
+ }
+
+ // Add the first slash if missing
+ p = "/" + p
+
+ // Clean path using golang tools
+ p = path.Clean(p)
+
+ exploded_path := strings.SplitN(p, "/", 3)
+
+ // If there is no bucket name (eg. "/")
+ if len(exploded_path) < 2 || exploded_path[1] == "" {
+ return S3Path{p, ROOT, "", ""}
+ }
+
+ // If there is no key
+ if len(exploded_path) < 3 || exploded_path[2] == "" {
+ return S3Path{p, BUCKET, exploded_path[1], ""}
+ }
+
+ return S3Path{p, OPAQUE_KEY, exploded_path[1], exploded_path[2]}
+}
+
+func NewTrustedS3Path(bucket string, obj minio.ObjectInfo) S3Path {
+ cl := OBJECT
+ if obj.Key[len(obj.Key)-1:] == "/" {
+ cl = COMMON_PREFIX
+ }
+
+ return S3Path{
+ Path: path.Join("/", bucket, obj.Key),
+ Bucket: bucket,
+ Key: obj.Key,
+ Class: cl,
+ }
+}
diff --git a/s3_stat.go b/s3/stat.go
index 65b6faa..96b0c24 100644
--- a/s3_stat.go
+++ b/s3/stat.go
@@ -1,4 +1,4 @@
-package main
+package s3
import (
"errors"
@@ -24,7 +24,7 @@ func NewS3StatFromObjectInfo(fs *S3FS, bucket string, obj minio.ObjectInfo) (*S3
s.obj = obj
s.fs = fs
- fs.cache[s.path.path] = s
+ fs.cache[s.path.Path] = s
return s, nil
}
@@ -44,30 +44,30 @@ func NewS3Stat(fs *S3FS, path string) (*S3Stat, error) {
return nil, err
}
- if s.path.class&OPAQUE_KEY != 0 {
+ if s.path.Class&OPAQUE_KEY != 0 {
return nil, errors.New("Failed to precisely determine the key type, this a logic error.")
}
cache[path] = s
- cache[s.path.path] = s
+ cache[s.path.Path] = s
return s, nil
}
func (s *S3Stat) Refresh() error {
- if s.path.class == ROOT || s.path.class == BUCKET {
+ if s.path.Class == ROOT || s.path.Class == BUCKET {
return nil
}
mc := s.fs.mc
// Compute the prefix to have the desired behaviour for our stat logic
- prefix := s.path.key
+ prefix := s.path.Key
if prefix[len(prefix)-1:] == "/" {
prefix = prefix[:len(prefix)-1]
}
// Get info and check if the key exists
- objs_info := mc.ListObjects(s.fs.ctx, s.path.bucket, minio.ListObjectsOptions{
+ objs_info := mc.ListObjects(s.fs.ctx, s.path.Bucket, minio.ListObjectsOptions{
Prefix: prefix,
Recursive: false,
})
@@ -80,7 +80,7 @@ func (s *S3Stat) Refresh() error {
if object.Key == prefix || object.Key == prefix+"/" {
s.obj = object
- s.path = NewTrustedS3Path(s.path.bucket, object)
+ s.path = NewTrustedS3Path(s.path.Bucket, object)
found = true
break
}
@@ -94,12 +94,12 @@ func (s *S3Stat) Refresh() error {
}
func (s *S3Stat) Name() string {
- if s.path.class == ROOT {
+ if s.path.Class == ROOT {
return "/"
- } else if s.path.class == BUCKET {
- return s.path.bucket
+ } else if s.path.Class == BUCKET {
+ return s.path.Bucket
} else {
- return path.Base(s.path.key)
+ return path.Base(s.path.Key)
}
}
@@ -108,7 +108,7 @@ func (s *S3Stat) Size() int64 {
}
func (s *S3Stat) Mode() fs.FileMode {
- if s.path.class == OBJECT {
+ if s.path.Class == OBJECT {
return fs.ModePerm
} else {
return fs.ModeDir | fs.ModePerm
@@ -120,7 +120,7 @@ func (s *S3Stat) ModTime() time.Time {
}
func (s *S3Stat) IsDir() bool {
- return s.path.class != OBJECT
+ return s.path.Class != OBJECT
}
func (s *S3Stat) Sys() interface{} {
diff --git a/s3_path.go b/s3_path.go
deleted file mode 100644
index e1933ef..0000000
--- a/s3_path.go
+++ /dev/null
@@ -1,57 +0,0 @@
-package main
-
-import (
- "path"
- "strings"
-
- "github.com/minio/minio-go/v7"
-)
-
-type S3Class int
-
-const (
- ROOT S3Class = 1 << iota
- BUCKET
- COMMON_PREFIX
- OBJECT
- OPAQUE_KEY
-
- KEY = COMMON_PREFIX | OBJECT | OPAQUE_KEY
-)
-
-type S3Path struct {
- path string
- class S3Class
- bucket string
- key string
-}
-
-func NewS3Path(path string) S3Path {
- exploded_path := strings.SplitN(path, "/", 3)
-
- // If there is no bucket name (eg. "/")
- if len(exploded_path) < 2 || exploded_path[1] == "" {
- return S3Path{path, ROOT, "", ""}
- }
-
- // If there is no key
- if len(exploded_path) < 3 || exploded_path[2] == "" {
- return S3Path{path, BUCKET, exploded_path[1], ""}
- }
-
- return S3Path{path, OPAQUE_KEY, exploded_path[1], exploded_path[2]}
-}
-
-func NewTrustedS3Path(bucket string, obj minio.ObjectInfo) S3Path {
- cl := OBJECT
- if obj.Key[len(obj.Key)-1:] == "/" {
- cl = COMMON_PREFIX
- }
-
- return S3Path{
- path: path.Join("/", bucket, obj.Key),
- bucket: bucket,
- key: obj.Key,
- class: cl,
- }
-}
diff --git a/sftp/allocator.go b/sftp/allocator.go
new file mode 100644
index 0000000..fc1b6f0
--- /dev/null
+++ b/sftp/allocator.go
@@ -0,0 +1,101 @@
+package sftp
+
+/*
+ Imported from: https://github.com/pkg/sftp
+ */
+
+import (
+ "sync"
+)
+
+type allocator struct {
+ sync.Mutex
+ available [][]byte
+ // map key is the request order
+ used map[uint32][][]byte
+}
+
+func newAllocator() *allocator {
+ return &allocator{
+ // micro optimization: initialize available pages with an initial capacity
+ available: make([][]byte, 0, SftpServerWorkerCount*2),
+ used: make(map[uint32][][]byte),
+ }
+}
+
+// GetPage returns a previously allocated and unused []byte or create a new one.
+// The slice have a fixed size = maxMsgLength, this value is suitable for both
+// receiving new packets and reading the files to serve
+func (a *allocator) GetPage(requestOrderID uint32) []byte {
+ a.Lock()
+ defer a.Unlock()
+
+ var result []byte
+
+ // get an available page and remove it from the available ones.
+ if len(a.available) > 0 {
+ truncLength := len(a.available) - 1
+ result = a.available[truncLength]
+
+ a.available[truncLength] = nil // clear out the internal pointer
+ a.available = a.available[:truncLength] // truncate the slice
+ }
+
+ // no preallocated slice found, just allocate a new one
+ if result == nil {
+ result = make([]byte, maxMsgLength)
+ }
+
+ // put result in used pages
+ a.used[requestOrderID] = append(a.used[requestOrderID], result)
+
+ return result
+}
+
+// ReleasePages marks unused all pages in use for the given requestID
+func (a *allocator) ReleasePages(requestOrderID uint32) {
+ a.Lock()
+ defer a.Unlock()
+
+ if used := a.used[requestOrderID]; len(used) > 0 {
+ a.available = append(a.available, used...)
+ }
+ delete(a.used, requestOrderID)
+}
+
+// Free removes all the used and available pages.
+// Call this method when the allocator is not needed anymore
+func (a *allocator) Free() {
+ a.Lock()
+ defer a.Unlock()
+
+ a.available = nil
+ a.used = make(map[uint32][][]byte)
+}
+
+func (a *allocator) countUsedPages() int {
+ a.Lock()
+ defer a.Unlock()
+
+ num := 0
+ for _, p := range a.used {
+ num += len(p)
+ }
+ return num
+}
+
+func (a *allocator) countAvailablePages() int {
+ a.Lock()
+ defer a.Unlock()
+
+ return len(a.available)
+}
+
+func (a *allocator) isRequestOrderIDUsed(requestOrderID uint32) bool {
+ a.Lock()
+ defer a.Unlock()
+
+ _, ok := a.used[requestOrderID]
+ return ok
+}
+
diff --git a/sftp/attrs.go b/sftp/attrs.go
new file mode 100644
index 0000000..2bb2d57
--- /dev/null
+++ b/sftp/attrs.go
@@ -0,0 +1,90 @@
+package sftp
+
+// ssh_FXP_ATTRS support
+// see http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-5
+
+import (
+ "os"
+ "time"
+)
+
+const (
+ sshFileXferAttrSize = 0x00000001
+ sshFileXferAttrUIDGID = 0x00000002
+ sshFileXferAttrPermissions = 0x00000004
+ sshFileXferAttrACmodTime = 0x00000008
+ sshFileXferAttrExtended = 0x80000000
+
+ sshFileXferAttrAll = sshFileXferAttrSize | sshFileXferAttrUIDGID | sshFileXferAttrPermissions |
+ sshFileXferAttrACmodTime | sshFileXferAttrExtended
+)
+
+// fileInfo is an artificial type designed to satisfy os.FileInfo.
+type fileInfo struct {
+ name string
+ stat *FileStat
+}
+
+// Name returns the base name of the file.
+func (fi *fileInfo) Name() string { return fi.name }
+
+// Size returns the length in bytes for regular files; system-dependent for others.
+func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) }
+
+// Mode returns file mode bits.
+func (fi *fileInfo) Mode() os.FileMode { return toFileMode(fi.stat.Mode) }
+
+// ModTime returns the last modification time of the file.
+func (fi *fileInfo) ModTime() time.Time { return time.Unix(int64(fi.stat.Mtime), 0) }
+
+// IsDir returns true if the file is a directory.
+func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() }
+
+func (fi *fileInfo) Sys() interface{} { return fi.stat }
+
+// FileStat holds the original unmarshalled values from a call to READDIR or
+// *STAT. It is exported for the purposes of accessing the raw values via
+// os.FileInfo.Sys(). It is also used server side to store the unmarshalled
+// values for SetStat.
+type FileStat struct {
+ Size uint64
+ Mode uint32
+ Mtime uint32
+ Atime uint32
+ UID uint32
+ GID uint32
+ Extended []StatExtended
+}
+
+// StatExtended contains additional, extended information for a FileStat.
+type StatExtended struct {
+ ExtType string
+ ExtData string
+}
+
+func fileInfoFromStat(stat *FileStat, name string) os.FileInfo {
+ return &fileInfo{
+ name: name,
+ stat: stat,
+ }
+}
+
+func fileStatFromInfo(fi os.FileInfo) (uint32, *FileStat) {
+ mtime := fi.ModTime().Unix()
+ atime := mtime
+ var flags uint32 = sshFileXferAttrSize |
+ sshFileXferAttrPermissions |
+ sshFileXferAttrACmodTime
+
+ fileStat := &FileStat{
+ Size: uint64(fi.Size()),
+ Mode: fromFileMode(fi.Mode()),
+ Mtime: uint32(mtime),
+ Atime: uint32(atime),
+ }
+
+ // os specific file stat decoding
+ fileStatFromInfoOs(fi, &flags, fileStat)
+
+ return flags, fileStat
+}
diff --git a/sftp/attrs_stubs.go b/sftp/attrs_stubs.go
new file mode 100644
index 0000000..c01f336
--- /dev/null
+++ b/sftp/attrs_stubs.go
@@ -0,0 +1,11 @@
+// +build plan9 windows android
+
+package sftp
+
+import (
+ "os"
+)
+
+func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) {
+ // todo
+}
diff --git a/sftp/attrs_unix.go b/sftp/attrs_unix.go
new file mode 100644
index 0000000..d1f4452
--- /dev/null
+++ b/sftp/attrs_unix.go
@@ -0,0 +1,16 @@
+// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris aix js
+
+package sftp
+
+import (
+ "os"
+ "syscall"
+)
+
+func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) {
+ if statt, ok := fi.Sys().(*syscall.Stat_t); ok {
+ *flags |= sshFileXferAttrUIDGID
+ fileStat.UID = statt.Uid
+ fileStat.GID = statt.Gid
+ }
+}
diff --git a/sftp/client.go b/sftp/client.go
new file mode 100644
index 0000000..ce62286
--- /dev/null
+++ b/sftp/client.go
@@ -0,0 +1,1936 @@
+package sftp
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "math"
+ "os"
+ "path"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "github.com/kr/fs"
+ "golang.org/x/crypto/ssh"
+)
+
+var (
+ // ErrInternalInconsistency indicates the packets sent and the data queued to be
+ // written to the file don't match up. It is an unusual error and usually is
+ // caused by bad behavior server side or connection issues. The error is
+ // limited in scope to the call where it happened, the client object is still
+ // OK to use as long as the connection is still open.
+ ErrInternalInconsistency = errors.New("internal inconsistency")
+ // InternalInconsistency alias for ErrInternalInconsistency.
+ //
+ // Deprecated: please use ErrInternalInconsistency
+ InternalInconsistency = ErrInternalInconsistency
+)
+
+// A ClientOption is a function which applies configuration to a Client.
+type ClientOption func(*Client) error
+
+// MaxPacketChecked sets the maximum size of the payload, measured in bytes.
+// This option only accepts sizes servers should support, ie. <= 32768 bytes.
+//
+// If you get the error "failed to send packet header: EOF" when copying a
+// large file, try lowering this number.
+//
+// The default packet size is 32768 bytes.
+func MaxPacketChecked(size int) ClientOption {
+ return func(c *Client) error {
+ if size < 1 {
+ return errors.New("size must be greater or equal to 1")
+ }
+ if size > 32768 {
+ return errors.New("sizes larger than 32KB might not work with all servers")
+ }
+ c.maxPacket = size
+ return nil
+ }
+}
+
+// MaxPacketUnchecked sets the maximum size of the payload, measured in bytes.
+// It accepts sizes larger than the 32768 bytes all servers should support.
+// Only use a setting higher than 32768 if your application always connects to
+// the same server or after sufficiently broad testing.
+//
+// If you get the error "failed to send packet header: EOF" when copying a
+// large file, try lowering this number.
+//
+// The default packet size is 32768 bytes.
+func MaxPacketUnchecked(size int) ClientOption {
+ return func(c *Client) error {
+ if size < 1 {
+ return errors.New("size must be greater or equal to 1")
+ }
+ c.maxPacket = size
+ return nil
+ }
+}
+
+// MaxPacket sets the maximum size of the payload, measured in bytes.
+// This option only accepts sizes servers should support, ie. <= 32768 bytes.
+// This is a synonym for MaxPacketChecked that provides backward compatibility.
+//
+// If you get the error "failed to send packet header: EOF" when copying a
+// large file, try lowering this number.
+//
+// The default packet size is 32768 bytes.
+func MaxPacket(size int) ClientOption {
+ return MaxPacketChecked(size)
+}
+
+// MaxConcurrentRequestsPerFile sets the maximum concurrent requests allowed for a single file.
+//
+// The default maximum concurrent requests is 64.
+func MaxConcurrentRequestsPerFile(n int) ClientOption {
+ return func(c *Client) error {
+ if n < 1 {
+ return errors.New("n must be greater or equal to 1")
+ }
+ c.maxConcurrentRequests = n
+ return nil
+ }
+}
+
+// UseConcurrentWrites allows the Client to perform concurrent Writes.
+//
+// Using concurrency while doing writes, requires special consideration.
+// A write to a later offset in a file after an error,
+// could end up with a file length longer than what was successfully written.
+//
+// When using this option, if you receive an error during `io.Copy` or `io.WriteTo`,
+// you may need to `Truncate` the target Writer to avoid “holes” in the data written.
+func UseConcurrentWrites(value bool) ClientOption {
+ return func(c *Client) error {
+ c.useConcurrentWrites = value
+ return nil
+ }
+}
+
+// UseConcurrentReads allows the Client to perform concurrent Reads.
+//
+// Concurrent reads are generally safe to use and not using them will degrade
+// performance, so this option is enabled by default.
+//
+// When enabled, WriteTo will use Stat/Fstat to get the file size and determines
+// how many concurrent workers to use.
+// Some "read once" servers will delete the file if they receive a stat call on an
+// open file and then the download will fail.
+// Disabling concurrent reads you will be able to download files from these servers.
+// If concurrent reads are disabled, the UseFstat option is ignored.
+func UseConcurrentReads(value bool) ClientOption {
+ return func(c *Client) error {
+ c.disableConcurrentReads = !value
+ return nil
+ }
+}
+
+// UseFstat sets whether to use Fstat or Stat when File.WriteTo is called
+// (usually when copying files).
+// Some servers limit the amount of open files and calling Stat after opening
+// the file will throw an error From the server. Setting this flag will call
+// Fstat instead of Stat which is suppose to be called on an open file handle.
+//
+// It has been found that that with IBM Sterling SFTP servers which have
+// "extractability" level set to 1 which means only 1 file can be opened at
+// any given time.
+//
+// If the server you are working with still has an issue with both Stat and
+// Fstat calls you can always open a file and read it until the end.
+//
+// Another reason to read the file until its end and Fstat doesn't work is
+// that in some servers, reading a full file will automatically delete the
+// file as some of these mainframes map the file to a message in a queue.
+// Once the file has been read it will get deleted.
+func UseFstat(value bool) ClientOption {
+ return func(c *Client) error {
+ c.useFstat = value
+ return nil
+ }
+}
+
+// Client represents an SFTP session on a *ssh.ClientConn SSH connection.
+// Multiple Clients can be active on a single SSH connection, and a Client
+// may be called concurrently from multiple Goroutines.
+//
+// Client implements the github.com/kr/fs.FileSystem interface.
+type Client struct {
+ clientConn
+
+ ext map[string]string // Extensions (name -> data).
+
+ maxPacket int // max packet size read or written.
+ maxConcurrentRequests int
+ nextid uint32
+
+ // write concurrency is… error prone.
+ // Default behavior should be to not use it.
+ useConcurrentWrites bool
+ useFstat bool
+ disableConcurrentReads bool
+}
+
+// NewClient creates a new SFTP client on conn, using zero or more option
+// functions.
+func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) {
+ s, err := conn.NewSession()
+ if err != nil {
+ return nil, err
+ }
+ if err := s.RequestSubsystem("sftp"); err != nil {
+ return nil, err
+ }
+ pw, err := s.StdinPipe()
+ if err != nil {
+ return nil, err
+ }
+ pr, err := s.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+
+ return NewClientPipe(pr, pw, opts...)
+}
+
+// NewClientPipe creates a new SFTP client given a Reader and a WriteCloser.
+// This can be used for connecting to an SFTP server over TCP/TLS or by using
+// the system's ssh client program (e.g. via exec.Command).
+func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) {
+ sftp := &Client{
+ clientConn: clientConn{
+ conn: conn{
+ Reader: rd,
+ WriteCloser: wr,
+ },
+ inflight: make(map[uint32]chan<- result),
+ closed: make(chan struct{}),
+ },
+
+ ext: make(map[string]string),
+
+ maxPacket: 1 << 15,
+ maxConcurrentRequests: 64,
+ }
+
+ for _, opt := range opts {
+ if err := opt(sftp); err != nil {
+ wr.Close()
+ return nil, err
+ }
+ }
+
+ if err := sftp.sendInit(); err != nil {
+ wr.Close()
+ return nil, err
+ }
+ if err := sftp.recvVersion(); err != nil {
+ wr.Close()
+ return nil, err
+ }
+
+ sftp.clientConn.wg.Add(1)
+ go sftp.loop()
+
+ return sftp, nil
+}
+
+// Create creates the named file mode 0666 (before umask), truncating it if it
+// already exists. If successful, methods on the returned File can be used for
+// I/O; the associated file descriptor has mode O_RDWR. If you need more
+// control over the flags/mode used to open the file see client.OpenFile.
+//
+// Note that some SFTP servers (eg. AWS Transfer) do not support opening files
+// read/write at the same time. For those services you will need to use
+// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
+func (c *Client) Create(path string) (*File, error) {
+ return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
+}
+
+const sftpProtocolVersion = 3 // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
+
+func (c *Client) sendInit() error {
+ return c.clientConn.conn.sendPacket(&sshFxInitPacket{
+ Version: sftpProtocolVersion, // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
+ })
+}
+
+// returns the next value of c.nextid
+func (c *Client) nextID() uint32 {
+ return atomic.AddUint32(&c.nextid, 1)
+}
+
+func (c *Client) recvVersion() error {
+ typ, data, err := c.recvPacket(0)
+ if err != nil {
+ return err
+ }
+ if typ != sshFxpVersion {
+ return &unexpectedPacketErr{sshFxpVersion, typ}
+ }
+
+ version, data, err := unmarshalUint32Safe(data)
+ if err != nil {
+ return err
+ }
+ if version != sftpProtocolVersion {
+ return &unexpectedVersionErr{sftpProtocolVersion, version}
+ }
+
+ for len(data) > 0 {
+ var ext extensionPair
+ ext, data, err = unmarshalExtensionPair(data)
+ if err != nil {
+ return err
+ }
+ c.ext[ext.Name] = ext.Data
+ }
+
+ return nil
+}
+
+// HasExtension checks whether the server supports a named extension.
+//
+// The first return value is the extension data reported by the server
+// (typically a version number).
+func (c *Client) HasExtension(name string) (string, bool) {
+ data, ok := c.ext[name]
+ return data, ok
+}
+
+// Walk returns a new Walker rooted at root.
+func (c *Client) Walk(root string) *fs.Walker {
+ return fs.WalkFS(root, c)
+}
+
+// ReadDir reads the directory named by dirname and returns a list of
+// directory entries.
+func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
+ handle, err := c.opendir(p)
+ if err != nil {
+ return nil, err
+ }
+ defer c.close(handle) // this has to defer earlier than the lock below
+ var attrs []os.FileInfo
+ var done = false
+ for !done {
+ id := c.nextID()
+ typ, data, err1 := c.sendPacket(nil, &sshFxpReaddirPacket{
+ ID: id,
+ Handle: handle,
+ })
+ if err1 != nil {
+ err = err1
+ done = true
+ break
+ }
+ switch typ {
+ case sshFxpName:
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return nil, &unexpectedIDErr{id, sid}
+ }
+ count, data := unmarshalUint32(data)
+ for i := uint32(0); i < count; i++ {
+ var filename string
+ filename, data = unmarshalString(data)
+ _, data = unmarshalString(data) // discard longname
+ var attr *FileStat
+ attr, data = unmarshalAttrs(data)
+ if filename == "." || filename == ".." {
+ continue
+ }
+ attrs = append(attrs, fileInfoFromStat(attr, path.Base(filename)))
+ }
+ case sshFxpStatus:
+ // TODO(dfc) scope warning!
+ err = normaliseError(unmarshalStatus(id, data))
+ done = true
+ default:
+ return nil, unimplementedPacketErr(typ)
+ }
+ }
+ if err == io.EOF {
+ err = nil
+ }
+ return attrs, err
+}
+
+func (c *Client) opendir(path string) (string, error) {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpOpendirPacket{
+ ID: id,
+ Path: path,
+ })
+ if err != nil {
+ return "", err
+ }
+ switch typ {
+ case sshFxpHandle:
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return "", &unexpectedIDErr{id, sid}
+ }
+ handle, _ := unmarshalString(data)
+ return handle, nil
+ case sshFxpStatus:
+ return "", normaliseError(unmarshalStatus(id, data))
+ default:
+ return "", unimplementedPacketErr(typ)
+ }
+}
+
+// Stat returns a FileInfo structure describing the file specified by path 'p'.
+// If 'p' is a symbolic link, the returned FileInfo structure describes the referent file.
+func (c *Client) Stat(p string) (os.FileInfo, error) {
+ fs, err := c.stat(p)
+ if err != nil {
+ return nil, err
+ }
+ return fileInfoFromStat(fs, path.Base(p)), nil
+}
+
+// Lstat returns a FileInfo structure describing the file specified by path 'p'.
+// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link.
+func (c *Client) Lstat(p string) (os.FileInfo, error) {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpLstatPacket{
+ ID: id,
+ Path: p,
+ })
+ if err != nil {
+ return nil, err
+ }
+ switch typ {
+ case sshFxpAttrs:
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return nil, &unexpectedIDErr{id, sid}
+ }
+ attr, _ := unmarshalAttrs(data)
+ return fileInfoFromStat(attr, path.Base(p)), nil
+ case sshFxpStatus:
+ return nil, normaliseError(unmarshalStatus(id, data))
+ default:
+ return nil, unimplementedPacketErr(typ)
+ }
+}
+
+// ReadLink reads the target of a symbolic link.
+func (c *Client) ReadLink(p string) (string, error) {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpReadlinkPacket{
+ ID: id,
+ Path: p,
+ })
+ if err != nil {
+ return "", err
+ }
+ switch typ {
+ case sshFxpName:
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return "", &unexpectedIDErr{id, sid}
+ }
+ count, data := unmarshalUint32(data)
+ if count != 1 {
+ return "", unexpectedCount(1, count)
+ }
+ filename, _ := unmarshalString(data) // ignore dummy attributes
+ return filename, nil
+ case sshFxpStatus:
+ return "", normaliseError(unmarshalStatus(id, data))
+ default:
+ return "", unimplementedPacketErr(typ)
+ }
+}
+
+// Link creates a hard link at 'newname', pointing at the same inode as 'oldname'
+func (c *Client) Link(oldname, newname string) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpHardlinkPacket{
+ ID: id,
+ Oldpath: oldname,
+ Newpath: newname,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+// Symlink creates a symbolic link at 'newname', pointing at target 'oldname'
+func (c *Client) Symlink(oldname, newname string) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpSymlinkPacket{
+ ID: id,
+ Linkpath: newname,
+ Targetpath: oldname,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpFsetstatPacket{
+ ID: id,
+ Handle: handle,
+ Flags: flags,
+ Attrs: attrs,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+// setstat is a convience wrapper to allow for changing of various parts of the file descriptor.
+func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpSetstatPacket{
+ ID: id,
+ Path: path,
+ Flags: flags,
+ Attrs: attrs,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+// Chtimes changes the access and modification times of the named file.
+func (c *Client) Chtimes(path string, atime time.Time, mtime time.Time) error {
+ type times struct {
+ Atime uint32
+ Mtime uint32
+ }
+ attrs := times{uint32(atime.Unix()), uint32(mtime.Unix())}
+ return c.setstat(path, sshFileXferAttrACmodTime, attrs)
+}
+
+// Chown changes the user and group owners of the named file.
+func (c *Client) Chown(path string, uid, gid int) error {
+ type owner struct {
+ UID uint32
+ GID uint32
+ }
+ attrs := owner{uint32(uid), uint32(gid)}
+ return c.setstat(path, sshFileXferAttrUIDGID, attrs)
+}
+
+// Chmod changes the permissions of the named file.
+//
+// Chmod does not apply a umask, because even retrieving the umask is not
+// possible in a portable way without causing a race condition. Callers
+// should mask off umask bits, if desired.
+func (c *Client) Chmod(path string, mode os.FileMode) error {
+ return c.setstat(path, sshFileXferAttrPermissions, toChmodPerm(mode))
+}
+
+// Truncate sets the size of the named file. Although it may be safely assumed
+// that if the size is less than its current size it will be truncated to fit,
+// the SFTP protocol does not specify what behavior the server should do when setting
+// size greater than the current size.
+func (c *Client) Truncate(path string, size int64) error {
+ return c.setstat(path, sshFileXferAttrSize, uint64(size))
+}
+
+// Open opens the named file for reading. If successful, methods on the
+// returned file can be used for reading; the associated file descriptor
+// has mode O_RDONLY.
+func (c *Client) Open(path string) (*File, error) {
+ return c.open(path, flags(os.O_RDONLY))
+}
+
+// OpenFile is the generalized open call; most users will use Open or
+// Create instead. It opens the named file with specified flag (O_RDONLY
+// etc.). If successful, methods on the returned File can be used for I/O.
+func (c *Client) OpenFile(path string, f int) (*File, error) {
+ return c.open(path, flags(f))
+}
+
+func (c *Client) open(path string, pflags uint32) (*File, error) {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpOpenPacket{
+ ID: id,
+ Path: path,
+ Pflags: pflags,
+ })
+ if err != nil {
+ return nil, err
+ }
+ switch typ {
+ case sshFxpHandle:
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return nil, &unexpectedIDErr{id, sid}
+ }
+ handle, _ := unmarshalString(data)
+ return &File{c: c, path: path, handle: handle}, nil
+ case sshFxpStatus:
+ return nil, normaliseError(unmarshalStatus(id, data))
+ default:
+ return nil, unimplementedPacketErr(typ)
+ }
+}
+
+// close closes a handle handle previously returned in the response
+// to SSH_FXP_OPEN or SSH_FXP_OPENDIR. The handle becomes invalid
+// immediately after this request has been sent.
+func (c *Client) close(handle string) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpClosePacket{
+ ID: id,
+ Handle: handle,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+func (c *Client) stat(path string) (*FileStat, error) {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{
+ ID: id,
+ Path: path,
+ })
+ if err != nil {
+ return nil, err
+ }
+ switch typ {
+ case sshFxpAttrs:
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return nil, &unexpectedIDErr{id, sid}
+ }
+ attr, _ := unmarshalAttrs(data)
+ return attr, nil
+ case sshFxpStatus:
+ return nil, normaliseError(unmarshalStatus(id, data))
+ default:
+ return nil, unimplementedPacketErr(typ)
+ }
+}
+
+func (c *Client) fstat(handle string) (*FileStat, error) {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{
+ ID: id,
+ Handle: handle,
+ })
+ if err != nil {
+ return nil, err
+ }
+ switch typ {
+ case sshFxpAttrs:
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return nil, &unexpectedIDErr{id, sid}
+ }
+ attr, _ := unmarshalAttrs(data)
+ return attr, nil
+ case sshFxpStatus:
+ return nil, normaliseError(unmarshalStatus(id, data))
+ default:
+ return nil, unimplementedPacketErr(typ)
+ }
+}
+
+// StatVFS retrieves VFS statistics from a remote host.
+//
+// It implements the statvfs@openssh.com SSH_FXP_EXTENDED feature
+// from http://www.opensource.apple.com/source/OpenSSH/OpenSSH-175/openssh/PROTOCOL?txt.
+func (c *Client) StatVFS(path string) (*StatVFS, error) {
+ // send the StatVFS packet to the server
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpStatvfsPacket{
+ ID: id,
+ Path: path,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ switch typ {
+ // server responded with valid data
+ case sshFxpExtendedReply:
+ var response StatVFS
+ err = binary.Read(bytes.NewReader(data), binary.BigEndian, &response)
+ if err != nil {
+ return nil, errors.New("can not parse reply")
+ }
+
+ return &response, nil
+
+ // the resquest failed
+ case sshFxpStatus:
+ return nil, normaliseError(unmarshalStatus(id, data))
+
+ default:
+ return nil, unimplementedPacketErr(typ)
+ }
+}
+
+// Join joins any number of path elements into a single path, adding a
+// separating slash if necessary. The result is Cleaned; in particular, all
+// empty strings are ignored.
+func (c *Client) Join(elem ...string) string { return path.Join(elem...) }
+
+// Remove removes the specified file or directory. An error will be returned if no
+// file or directory with the specified path exists, or if the specified directory
+// is not empty.
+func (c *Client) Remove(path string) error {
+ err := c.removeFile(path)
+ // some servers, *cough* osx *cough*, return EPERM, not ENODIR.
+ // serv-u returns ssh_FX_FILE_IS_A_DIRECTORY
+ // EPERM is converted to os.ErrPermission so it is not a StatusError
+ if err, ok := err.(*StatusError); ok {
+ switch err.Code {
+ case sshFxFailure, sshFxFileIsADirectory:
+ return c.RemoveDirectory(path)
+ }
+ }
+ if os.IsPermission(err) {
+ return c.RemoveDirectory(path)
+ }
+ return err
+}
+
+func (c *Client) removeFile(path string) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpRemovePacket{
+ ID: id,
+ Filename: path,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+// RemoveDirectory removes a directory path.
+func (c *Client) RemoveDirectory(path string) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpRmdirPacket{
+ ID: id,
+ Path: path,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+// Rename renames a file.
+func (c *Client) Rename(oldname, newname string) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpRenamePacket{
+ ID: id,
+ Oldpath: oldname,
+ Newpath: newname,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+// PosixRename renames a file using the posix-rename@openssh.com extension
+// which will replace newname if it already exists.
+func (c *Client) PosixRename(oldname, newname string) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpPosixRenamePacket{
+ ID: id,
+ Oldpath: oldname,
+ Newpath: newname,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+// RealPath can be used to have the server canonicalize any given path name to an absolute path.
+//
+// This is useful for converting path names containing ".." components,
+// or relative pathnames without a leading slash into absolute paths.
+func (c *Client) RealPath(path string) (string, error) {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpRealpathPacket{
+ ID: id,
+ Path: path,
+ })
+ if err != nil {
+ return "", err
+ }
+ switch typ {
+ case sshFxpName:
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return "", &unexpectedIDErr{id, sid}
+ }
+ count, data := unmarshalUint32(data)
+ if count != 1 {
+ return "", unexpectedCount(1, count)
+ }
+ filename, _ := unmarshalString(data) // ignore attributes
+ return filename, nil
+ case sshFxpStatus:
+ return "", normaliseError(unmarshalStatus(id, data))
+ default:
+ return "", unimplementedPacketErr(typ)
+ }
+}
+
+// Getwd returns the current working directory of the server. Operations
+// involving relative paths will be based at this location.
+func (c *Client) Getwd() (string, error) {
+ return c.RealPath(".")
+}
+
+// Mkdir creates the specified directory. An error will be returned if a file or
+// directory with the specified path already exists, or if the directory's
+// parent folder does not exist (the method cannot create complete paths).
+func (c *Client) Mkdir(path string) error {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpMkdirPacket{
+ ID: id,
+ Path: path,
+ })
+ if err != nil {
+ return err
+ }
+ switch typ {
+ case sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return unimplementedPacketErr(typ)
+ }
+}
+
+// MkdirAll creates a directory named path, along with any necessary parents,
+// and returns nil, or else returns an error.
+// If path is already a directory, MkdirAll does nothing and returns nil.
+// If path contains a regular file, an error is returned
+func (c *Client) MkdirAll(path string) error {
+ // Most of this code mimics https://golang.org/src/os/path.go?s=514:561#L13
+ // Fast path: if we can tell whether path is a directory or file, stop with success or error.
+ dir, err := c.Stat(path)
+ if err == nil {
+ if dir.IsDir() {
+ return nil
+ }
+ return &os.PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR}
+ }
+
+ // Slow path: make sure parent exists and then call Mkdir for path.
+ i := len(path)
+ for i > 0 && path[i-1] == '/' { // Skip trailing path separator.
+ i--
+ }
+
+ j := i
+ for j > 0 && path[j-1] != '/' { // Scan backward over element.
+ j--
+ }
+
+ if j > 1 {
+ // Create parent
+ err = c.MkdirAll(path[0 : j-1])
+ if err != nil {
+ return err
+ }
+ }
+
+ // Parent now exists; invoke Mkdir and use its result.
+ err = c.Mkdir(path)
+ if err != nil {
+ // Handle arguments like "foo/." by
+ // double-checking that directory doesn't exist.
+ dir, err1 := c.Lstat(path)
+ if err1 == nil && dir.IsDir() {
+ return nil
+ }
+ return err
+ }
+ return nil
+}
+
+// File represents a remote file.
+type File struct {
+ c *Client
+ path string
+ handle string
+
+ mu sync.Mutex
+ offset int64 // current offset within remote file
+}
+
+// Close closes the File, rendering it unusable for I/O. It returns an
+// error, if any.
+func (f *File) Close() error {
+ return f.c.close(f.handle)
+}
+
+// Name returns the name of the file as presented to Open or Create.
+func (f *File) Name() string {
+ return f.path
+}
+
+// Read reads up to len(b) bytes from the File. It returns the number of bytes
+// read and an error, if any. Read follows io.Reader semantics, so when Read
+// encounters an error or EOF condition after successfully reading n > 0 bytes,
+// it returns the number of bytes read.
+//
+// To maximise throughput for transferring the entire file (especially
+// over high latency links) it is recommended to use WriteTo rather
+// than calling Read multiple times. io.Copy will do this
+// automatically.
+func (f *File) Read(b []byte) (int, error) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ n, err := f.ReadAt(b, f.offset)
+ f.offset += int64(n)
+ return n, err
+}
+
+// readChunkAt attempts to read the whole entire length of the buffer from the file starting at the offset.
+// It will continue progressively reading into the buffer until it fills the whole buffer, or an error occurs.
+func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) {
+ for err == nil && n < len(b) {
+ id := f.c.nextID()
+ typ, data, err := f.c.sendPacket(ch, &sshFxpReadPacket{
+ ID: id,
+ Handle: f.handle,
+ Offset: uint64(off) + uint64(n),
+ Len: uint32(len(b) - n),
+ })
+ if err != nil {
+ return n, err
+ }
+
+ switch typ {
+ case sshFxpStatus:
+ return n, normaliseError(unmarshalStatus(id, data))
+
+ case sshFxpData:
+ sid, data := unmarshalUint32(data)
+ if id != sid {
+ return n, &unexpectedIDErr{id, sid}
+ }
+
+ l, data := unmarshalUint32(data)
+ n += copy(b[n:], data[:l])
+
+ default:
+ return n, unimplementedPacketErr(typ)
+ }
+ }
+
+ return
+}
+
+func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
+ for read < len(b) {
+ rb := b[read:]
+ if len(rb) > f.c.maxPacket {
+ rb = rb[:f.c.maxPacket]
+ }
+ n, err := f.readChunkAt(nil, rb, off+int64(read))
+ if n < 0 {
+ panic("sftp.File: returned negative count from readChunkAt")
+ }
+ if n > 0 {
+ read += n
+ }
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ return read, nil // return nil explicitly.
+ }
+ return read, err
+ }
+ }
+ return read, nil
+}
+
+// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns
+// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
+// so the file offset is not altered during the read.
+func (f *File) ReadAt(b []byte, off int64) (int, error) {
+ if len(b) <= f.c.maxPacket {
+ // This should be able to be serviced with 1/2 requests.
+ // So, just do it directly.
+ return f.readChunkAt(nil, b, off)
+ }
+
+ if f.c.disableConcurrentReads {
+ return f.readAtSequential(b, off)
+ }
+
+ // Split the read into multiple maxPacket-sized concurrent reads bounded by maxConcurrentRequests.
+ // This allows writes with a suitably large buffer to transfer data at a much faster rate
+ // by overlapping round trip times.
+
+ cancel := make(chan struct{})
+
+ concurrency := len(b)/f.c.maxPacket + 1
+ if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
+ concurrency = f.c.maxConcurrentRequests
+ }
+
+ resPool := newResChanPool(concurrency)
+
+ type work struct {
+ id uint32
+ res chan result
+
+ b []byte
+ off int64
+ }
+ workCh := make(chan work)
+
+ // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
+ go func() {
+ defer close(workCh)
+
+ b := b
+ offset := off
+ chunkSize := f.c.maxPacket
+
+ for len(b) > 0 {
+ rb := b
+ if len(rb) > chunkSize {
+ rb = rb[:chunkSize]
+ }
+
+ id := f.c.nextID()
+ res := resPool.Get()
+
+ f.c.dispatchRequest(res, &sshFxpReadPacket{
+ ID: id,
+ Handle: f.handle,
+ Offset: uint64(offset),
+ Len: uint32(chunkSize),
+ })
+
+ select {
+ case workCh <- work{id, res, rb, offset}:
+ case <-cancel:
+ return
+ }
+
+ offset += int64(len(rb))
+ b = b[len(rb):]
+ }
+ }()
+
+ type rErr struct {
+ off int64
+ err error
+ }
+ errCh := make(chan rErr)
+
+ var wg sync.WaitGroup
+ wg.Add(concurrency)
+ for i := 0; i < concurrency; i++ {
+ // Map_i: each worker gets work, and then performs the Read into its buffer from its respective offset.
+ go func() {
+ defer wg.Done()
+
+ for packet := range workCh {
+ var n int
+
+ s := <-packet.res
+ resPool.Put(packet.res)
+
+ err := s.err
+ if err == nil {
+ switch s.typ {
+ case sshFxpStatus:
+ err = normaliseError(unmarshalStatus(packet.id, s.data))
+
+ case sshFxpData:
+ sid, data := unmarshalUint32(s.data)
+ if packet.id != sid {
+ err = &unexpectedIDErr{packet.id, sid}
+
+ } else {
+ l, data := unmarshalUint32(data)
+ n = copy(packet.b, data[:l])
+
+ // For normal disk files, it is guaranteed that this will read
+ // the specified number of bytes, or up to end of file.
+ // This implies, if we have a short read, that means EOF.
+ if n < len(packet.b) {
+ err = io.EOF
+ }
+ }
+
+ default:
+ err = unimplementedPacketErr(s.typ)
+ }
+ }
+
+ if err != nil {
+ // return the offset as the start + how much we read before the error.
+ errCh <- rErr{packet.off + int64(n), err}
+ return
+ }
+ }
+ }()
+ }
+
+ // Wait for long tail, before closing results.
+ go func() {
+ wg.Wait()
+ close(errCh)
+ }()
+
+ // Reduce: collect all the results into a relevant return: the earliest offset to return an error.
+ firstErr := rErr{math.MaxInt64, nil}
+ for rErr := range errCh {
+ if rErr.off <= firstErr.off {
+ firstErr = rErr
+ }
+
+ select {
+ case <-cancel:
+ default:
+ // stop any more work from being distributed. (Just in case.)
+ close(cancel)
+ }
+ }
+
+ if firstErr.err != nil {
+ // firstErr.err != nil if and only if firstErr.off > our starting offset.
+ return int(firstErr.off - off), firstErr.err
+ }
+
+ // As per spec for io.ReaderAt, we return nil error if and only if we read everything.
+ return len(b), nil
+}
+
+// writeToSequential implements WriteTo, but works sequentially with no parallelism.
+func (f *File) writeToSequential(w io.Writer) (written int64, err error) {
+ b := make([]byte, f.c.maxPacket)
+ ch := make(chan result, 1) // reusable channel
+
+ for {
+ n, err := f.readChunkAt(ch, b, f.offset)
+ if n < 0 {
+ panic("sftp.File: returned negative count from readChunkAt")
+ }
+
+ if n > 0 {
+ f.offset += int64(n)
+
+ m, err2 := w.Write(b[:n])
+ written += int64(m)
+
+ if err == nil {
+ err = err2
+ }
+ }
+
+ if err != nil {
+ if err == io.EOF {
+ return written, nil // return nil explicitly.
+ }
+
+ return written, err
+ }
+ }
+}
+
+// WriteTo writes the file to the given Writer.
+// The return value is the number of bytes written.
+// Any error encountered during the write is also returned.
+//
+// This method is preferred over calling Read multiple times
+// to maximise throughput for transferring the entire file,
+// especially over high latency links.
+func (f *File) WriteTo(w io.Writer) (written int64, err error) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ if f.c.disableConcurrentReads {
+ return f.writeToSequential(w)
+ }
+
+ // For concurrency, we want to guess how many concurrent workers we should use.
+ var fileStat *FileStat
+ if f.c.useFstat {
+ fileStat, err = f.c.fstat(f.handle)
+ } else {
+ fileStat, err = f.c.stat(f.path)
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ fileSize := fileStat.Size
+ if fileSize <= uint64(f.c.maxPacket) || !isRegular(fileStat.Mode) {
+ // only regular files are guaranteed to return (full read) xor (partial read, next error)
+ return f.writeToSequential(w)
+ }
+
+ concurrency64 := fileSize/uint64(f.c.maxPacket) + 1 // a bad guess, but better than no guess
+ if concurrency64 > uint64(f.c.maxConcurrentRequests) || concurrency64 < 1 {
+ concurrency64 = uint64(f.c.maxConcurrentRequests)
+ }
+ // Now that concurrency64 is saturated to an int value, we know this assignment cannot possibly overflow.
+ concurrency := int(concurrency64)
+
+ chunkSize := f.c.maxPacket
+ pool := newBufPool(concurrency, chunkSize)
+ resPool := newResChanPool(concurrency)
+
+ cancel := make(chan struct{})
+ var wg sync.WaitGroup
+ defer func() {
+ // Once the writing Reduce phase has ended, all the feed work needs to unconditionally stop.
+ close(cancel)
+
+ // We want to wait until all outstanding goroutines with an `f` or `f.c` reference have completed.
+ // Just to be sure we don’t orphan any goroutines any hanging references.
+ wg.Wait()
+ }()
+
+ type writeWork struct {
+ b []byte
+ off int64
+ err error
+
+ next chan writeWork
+ }
+ writeCh := make(chan writeWork)
+
+ type readWork struct {
+ id uint32
+ res chan result
+ off int64
+
+ cur, next chan writeWork
+ }
+ readCh := make(chan readWork)
+
+ // Slice: hand out chunks of work on demand, with a `cur` and `next` channel built-in for sequencing.
+ go func() {
+ defer close(readCh)
+
+ off := f.offset
+
+ cur := writeCh
+ for {
+ id := f.c.nextID()
+ res := resPool.Get()
+
+ next := make(chan writeWork)
+ readWork := readWork{
+ id: id,
+ res: res,
+ off: off,
+
+ cur: cur,
+ next: next,
+ }
+
+ f.c.dispatchRequest(res, &sshFxpReadPacket{
+ ID: id,
+ Handle: f.handle,
+ Offset: uint64(off),
+ Len: uint32(chunkSize),
+ })
+
+ select {
+ case readCh <- readWork:
+ case <-cancel:
+ return
+ }
+
+ off += int64(chunkSize)
+ cur = next
+ }
+ }()
+
+ wg.Add(concurrency)
+ for i := 0; i < concurrency; i++ {
+ // Map_i: each worker gets readWork, and does the Read into a buffer at the given offset.
+ go func() {
+ defer wg.Done()
+
+ for readWork := range readCh {
+ var b []byte
+ var n int
+
+ s := <-readWork.res
+ resPool.Put(readWork.res)
+
+ err := s.err
+ if err == nil {
+ switch s.typ {
+ case sshFxpStatus:
+ err = normaliseError(unmarshalStatus(readWork.id, s.data))
+
+ case sshFxpData:
+ sid, data := unmarshalUint32(s.data)
+ if readWork.id != sid {
+ err = &unexpectedIDErr{readWork.id, sid}
+
+ } else {
+ l, data := unmarshalUint32(data)
+ b = pool.Get()[:l]
+ n = copy(b, data[:l])
+ b = b[:n]
+ }
+
+ default:
+ err = unimplementedPacketErr(s.typ)
+ }
+ }
+
+ writeWork := writeWork{
+ b: b,
+ off: readWork.off,
+ err: err,
+
+ next: readWork.next,
+ }
+
+ select {
+ case readWork.cur <- writeWork:
+ case <-cancel:
+ return
+ }
+
+ if err != nil {
+ return
+ }
+ }
+ }()
+ }
+
+ // Reduce: serialize the results from the reads into sequential writes.
+ cur := writeCh
+ for {
+ packet, ok := <-cur
+ if !ok {
+ return written, errors.New("sftp.File.WriteTo: unexpectedly closed channel")
+ }
+
+ // Because writes are serialized, this will always be the last successfully read byte.
+ f.offset = packet.off + int64(len(packet.b))
+
+ if len(packet.b) > 0 {
+ n, err := w.Write(packet.b)
+ written += int64(n)
+ if err != nil {
+ return written, err
+ }
+ }
+
+ if packet.err != nil {
+ if packet.err == io.EOF {
+ return written, nil
+ }
+
+ return written, packet.err
+ }
+
+ pool.Put(packet.b)
+ cur = packet.next
+ }
+}
+
+// Stat returns the FileInfo structure describing file. If there is an
+// error.
+func (f *File) Stat() (os.FileInfo, error) {
+ fs, err := f.c.fstat(f.handle)
+ if err != nil {
+ return nil, err
+ }
+ return fileInfoFromStat(fs, path.Base(f.path)), nil
+}
+
+// Write writes len(b) bytes to the File. It returns the number of bytes
+// written and an error, if any. Write returns a non-nil error when n !=
+// len(b).
+//
+// To maximise throughput for transferring the entire file (especially
+// over high latency links) it is recommended to use ReadFrom rather
+// than calling Write multiple times. io.Copy will do this
+// automatically.
+func (f *File) Write(b []byte) (int, error) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ n, err := f.WriteAt(b, f.offset)
+ f.offset += int64(n)
+ return n, err
+}
+
+func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) {
+ typ, data, err := f.c.sendPacket(ch, &sshFxpWritePacket{
+ ID: f.c.nextID(),
+ Handle: f.handle,
+ Offset: uint64(off),
+ Length: uint32(len(b)),
+ Data: b,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ switch typ {
+ case sshFxpStatus:
+ id, _ := unmarshalUint32(data)
+ err := normaliseError(unmarshalStatus(id, data))
+ if err != nil {
+ return 0, err
+ }
+
+ default:
+ return 0, unimplementedPacketErr(typ)
+ }
+
+ return len(b), nil
+}
+
+// writeAtConcurrent implements WriterAt, but works concurrently rather than sequentially.
+func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
+ // Split the write into multiple maxPacket sized concurrent writes
+ // bounded by maxConcurrentRequests. This allows writes with a suitably
+ // large buffer to transfer data at a much faster rate due to
+ // overlapping round trip times.
+
+ cancel := make(chan struct{})
+
+ type work struct {
+ b []byte
+ off int64
+ }
+ workCh := make(chan work)
+
+ // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
+ go func() {
+ defer close(workCh)
+
+ var read int
+ chunkSize := f.c.maxPacket
+
+ for read < len(b) {
+ wb := b[read:]
+ if len(wb) > chunkSize {
+ wb = wb[:chunkSize]
+ }
+
+ select {
+ case workCh <- work{wb, off + int64(read)}:
+ case <-cancel:
+ return
+ }
+
+ read += len(wb)
+ }
+ }()
+
+ type wErr struct {
+ off int64
+ err error
+ }
+ errCh := make(chan wErr)
+
+ concurrency := len(b)/f.c.maxPacket + 1
+ if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
+ concurrency = f.c.maxConcurrentRequests
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(concurrency)
+ for i := 0; i < concurrency; i++ {
+ // Map_i: each worker gets work, and does the Write from each buffer to its respective offset.
+ go func() {
+ defer wg.Done()
+
+ ch := make(chan result, 1) // reusable channel per mapper.
+
+ for packet := range workCh {
+ n, err := f.writeChunkAt(ch, packet.b, packet.off)
+ if err != nil {
+ // return the offset as the start + how much we wrote before the error.
+ errCh <- wErr{packet.off + int64(n), err}
+ }
+ }
+ }()
+ }
+
+ // Wait for long tail, before closing results.
+ go func() {
+ wg.Wait()
+ close(errCh)
+ }()
+
+ // Reduce: collect all the results into a relevant return: the earliest offset to return an error.
+ firstErr := wErr{math.MaxInt64, nil}
+ for wErr := range errCh {
+ if wErr.off <= firstErr.off {
+ firstErr = wErr
+ }
+
+ select {
+ case <-cancel:
+ default:
+ // stop any more work from being distributed. (Just in case.)
+ close(cancel)
+ }
+ }
+
+ if firstErr.err != nil {
+ // firstErr.err != nil if and only if firstErr.off >= our starting offset.
+ return int(firstErr.off - off), firstErr.err
+ }
+
+ return len(b), nil
+}
+
+// WriteAt writes up to len(b) byte to the File at a given offset `off`. It returns
+// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics,
+// so the file offset is not altered during the write.
+func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
+ if len(b) <= f.c.maxPacket {
+ // We can do this in one write.
+ return f.writeChunkAt(nil, b, off)
+ }
+
+ if f.c.useConcurrentWrites {
+ return f.writeAtConcurrent(b, off)
+ }
+
+ ch := make(chan result, 1) // reusable channel
+
+ chunkSize := f.c.maxPacket
+
+ for written < len(b) {
+ wb := b[written:]
+ if len(wb) > chunkSize {
+ wb = wb[:chunkSize]
+ }
+
+ n, err := f.writeChunkAt(ch, wb, off+int64(written))
+ if n > 0 {
+ written += n
+ }
+
+ if err != nil {
+ return written, err
+ }
+ }
+
+ return len(b), nil
+}
+
+// ReadFromWithConcurrency implements ReaderFrom,
+// but uses the given concurrency to issue multiple requests at the same time.
+//
+// Giving a concurrency of less than one will default to the Client’s max concurrency.
+//
+// Otherwise, the given concurrency will be capped by the Client's max concurrency.
+func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
+ // Split the write into multiple maxPacket sized concurrent writes.
+ // This allows writes with a suitably large reader
+ // to transfer data at a much faster rate due to overlapping round trip times.
+
+ cancel := make(chan struct{})
+
+ type work struct {
+ b []byte
+ n int
+ off int64
+ }
+ workCh := make(chan work)
+
+ type rwErr struct {
+ off int64
+ err error
+ }
+ errCh := make(chan rwErr)
+
+ if concurrency > f.c.maxConcurrentRequests || concurrency < 1 {
+ concurrency = f.c.maxConcurrentRequests
+ }
+
+ pool := newBufPool(concurrency, f.c.maxPacket)
+
+ // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets.
+ go func() {
+ defer close(workCh)
+
+ off := f.offset
+
+ for {
+ b := pool.Get()
+
+ n, err := r.Read(b)
+ if n > 0 {
+ read += int64(n)
+
+ select {
+ case workCh <- work{b, n, off}:
+ // We need the pool.Put(b) to put the whole slice, not just trunced.
+ case <-cancel:
+ return
+ }
+
+ off += int64(n)
+ }
+
+ if err != nil {
+ if err != io.EOF {
+ errCh <- rwErr{off, err}
+ }
+ return
+ }
+ }
+ }()
+
+ var wg sync.WaitGroup
+ wg.Add(concurrency)
+ for i := 0; i < concurrency; i++ {
+ // Map_i: each worker gets work, and does the Write from each buffer to its respective offset.
+ go func() {
+ defer wg.Done()
+
+ ch := make(chan result, 1) // reusable channel per mapper.
+
+ for packet := range workCh {
+ n, err := f.writeChunkAt(ch, packet.b[:packet.n], packet.off)
+ if err != nil {
+ // return the offset as the start + how much we wrote before the error.
+ errCh <- rwErr{packet.off + int64(n), err}
+ }
+ pool.Put(packet.b)
+ }
+ }()
+ }
+
+ // Wait for long tail, before closing results.
+ go func() {
+ wg.Wait()
+ close(errCh)
+ }()
+
+ // Reduce: Collect all the results into a relevant return: the earliest offset to return an error.
+ firstErr := rwErr{math.MaxInt64, nil}
+ for rwErr := range errCh {
+ if rwErr.off <= firstErr.off {
+ firstErr = rwErr
+ }
+
+ select {
+ case <-cancel:
+ default:
+ // stop any more work from being distributed.
+ close(cancel)
+ }
+ }
+
+ if firstErr.err != nil {
+ // firstErr.err != nil if and only if firstErr.off is a valid offset.
+ //
+ // firstErr.off will then be the lesser of:
+ // * the offset of the first error from writing,
+ // * the last successfully read offset.
+ //
+ // This could be less than the last successfully written offset,
+ // which is the whole reason for the UseConcurrentWrites() ClientOption.
+ //
+ // Callers are responsible for truncating any SFTP files to a safe length.
+ f.offset = firstErr.off
+
+ // ReadFrom is defined to return the read bytes, regardless of any writer errors.
+ return read, firstErr.err
+ }
+
+ f.offset += read
+ return read, nil
+}
+
+// ReadFrom reads data from r until EOF and writes it to the file. The return
+// value is the number of bytes read. Any error except io.EOF encountered
+// during the read is also returned.
+//
+// This method is preferred over calling Write multiple times
+// to maximise throughput for transferring the entire file,
+// especially over high-latency links.
+func (f *File) ReadFrom(r io.Reader) (int64, error) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ if f.c.useConcurrentWrites {
+ var remain int64
+ switch r := r.(type) {
+ case interface{ Len() int }:
+ remain = int64(r.Len())
+
+ case interface{ Size() int64 }:
+ remain = r.Size()
+
+ case *io.LimitedReader:
+ remain = r.N
+
+ case interface{ Stat() (os.FileInfo, error) }:
+ info, err := r.Stat()
+ if err == nil {
+ remain = info.Size()
+ }
+ }
+
+ if remain < 0 {
+ // We can strongly assert that we want default max concurrency here.
+ return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests)
+ }
+
+ if remain > int64(f.c.maxPacket) {
+ // Otherwise, only use concurrency, if it would be at least two packets.
+
+ // This is the best reasonable guess we can make.
+ concurrency64 := remain/int64(f.c.maxPacket) + 1
+
+ // We need to cap this value to an `int` size value to avoid overflow on 32-bit machines.
+ // So, we may as well pre-cap it to `f.c.maxConcurrentRequests`.
+ if concurrency64 > int64(f.c.maxConcurrentRequests) {
+ concurrency64 = int64(f.c.maxConcurrentRequests)
+ }
+
+ return f.ReadFromWithConcurrency(r, int(concurrency64))
+ }
+ }
+
+ ch := make(chan result, 1) // reusable channel
+
+ b := make([]byte, f.c.maxPacket)
+
+ var read int64
+ for {
+ n, err := r.Read(b)
+ if n < 0 {
+ panic("sftp.File: reader returned negative count from Read")
+ }
+
+ if n > 0 {
+ read += int64(n)
+
+ m, err2 := f.writeChunkAt(ch, b[:n], f.offset)
+ f.offset += int64(m)
+
+ if err == nil {
+ err = err2
+ }
+ }
+
+ if err != nil {
+ if err == io.EOF {
+ return read, nil // return nil explicitly.
+ }
+
+ return read, err
+ }
+ }
+}
+
+// Seek implements io.Seeker by setting the client offset for the next Read or
+// Write. It returns the next offset read. Seeking before or after the end of
+// the file is undefined. Seeking relative to the end calls Stat.
+func (f *File) Seek(offset int64, whence int) (int64, error) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ switch whence {
+ case io.SeekStart:
+ case io.SeekCurrent:
+ offset += f.offset
+ case io.SeekEnd:
+ fi, err := f.Stat()
+ if err != nil {
+ return f.offset, err
+ }
+ offset += fi.Size()
+ default:
+ return f.offset, unimplementedSeekWhence(whence)
+ }
+
+ if offset < 0 {
+ return f.offset, os.ErrInvalid
+ }
+
+ f.offset = offset
+ return f.offset, nil
+}
+
+// Chown changes the uid/gid of the current file.
+func (f *File) Chown(uid, gid int) error {
+ return f.c.Chown(f.path, uid, gid)
+}
+
+// Chmod changes the permissions of the current file.
+//
+// See Client.Chmod for details.
+func (f *File) Chmod(mode os.FileMode) error {
+ return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
+}
+
+// Sync requests a flush of the contents of a File to stable storage.
+//
+// Sync requires the server to support the fsync@openssh.com extension.
+func (f *File) Sync() error {
+ id := f.c.nextID()
+ typ, data, err := f.c.sendPacket(nil, &sshFxpFsyncPacket{
+ ID: id,
+ Handle: f.handle,
+ })
+
+ switch {
+ case err != nil:
+ return err
+ case typ == sshFxpStatus:
+ return normaliseError(unmarshalStatus(id, data))
+ default:
+ return &unexpectedPacketErr{want: sshFxpStatus, got: typ}
+ }
+}
+
+// Truncate sets the size of the current file. Although it may be safely assumed
+// that if the size is less than its current size it will be truncated to fit,
+// the SFTP protocol does not specify what behavior the server should do when setting
+// size greater than the current size.
+// We send a SSH_FXP_FSETSTAT here since we have a file handle
+func (f *File) Truncate(size int64) error {
+ return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
+}
+
+// normaliseError normalises an error into a more standard form that can be
+// checked against stdlib errors like io.EOF or os.ErrNotExist.
+func normaliseError(err error) error {
+ switch err := err.(type) {
+ case *StatusError:
+ switch err.Code {
+ case sshFxEOF:
+ return io.EOF
+ case sshFxNoSuchFile:
+ return os.ErrNotExist
+ case sshFxPermissionDenied:
+ return os.ErrPermission
+ case sshFxOk:
+ return nil
+ default:
+ return err
+ }
+ default:
+ return err
+ }
+}
+
+// flags converts the flags passed to OpenFile into ssh flags.
+// Unsupported flags are ignored.
+func flags(f int) uint32 {
+ var out uint32
+ switch f & os.O_WRONLY {
+ case os.O_WRONLY:
+ out |= sshFxfWrite
+ case os.O_RDONLY:
+ out |= sshFxfRead
+ }
+ if f&os.O_RDWR == os.O_RDWR {
+ out |= sshFxfRead | sshFxfWrite
+ }
+ if f&os.O_APPEND == os.O_APPEND {
+ out |= sshFxfAppend
+ }
+ if f&os.O_CREATE == os.O_CREATE {
+ out |= sshFxfCreat
+ }
+ if f&os.O_TRUNC == os.O_TRUNC {
+ out |= sshFxfTrunc
+ }
+ if f&os.O_EXCL == os.O_EXCL {
+ out |= sshFxfExcl
+ }
+ return out
+}
+
+// toChmodPerm converts Go permission bits to POSIX permission bits.
+//
+// This differs from fromFileMode in that we preserve the POSIX versions of
+// setuid, setgid and sticky in m, because we've historically supported those
+// bits, and we mask off any non-permission bits.
+func toChmodPerm(m os.FileMode) (perm uint32) {
+ const mask = os.ModePerm | s_ISUID | s_ISGID | s_ISVTX
+ perm = uint32(m & mask)
+
+ if m&os.ModeSetuid != 0 {
+ perm |= s_ISUID
+ }
+ if m&os.ModeSetgid != 0 {
+ perm |= s_ISGID
+ }
+ if m&os.ModeSticky != 0 {
+ perm |= s_ISVTX
+ }
+
+ return perm
+}
diff --git a/sftp/conn.go b/sftp/conn.go
new file mode 100644
index 0000000..7d95142
--- /dev/null
+++ b/sftp/conn.go
@@ -0,0 +1,189 @@
+package sftp
+
+import (
+ "encoding"
+ "fmt"
+ "io"
+ "sync"
+)
+
+// conn implements a bidirectional channel on which client and server
+// connections are multiplexed.
+type conn struct {
+ io.Reader
+ io.WriteCloser
+ // this is the same allocator used in packet manager
+ alloc *allocator
+ sync.Mutex // used to serialise writes to sendPacket
+}
+
+// the orderID is used in server mode if the allocator is enabled.
+// For the client mode just pass 0
+func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
+ return recvPacket(c, c.alloc, orderID)
+}
+
+func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
+ c.Lock()
+ defer c.Unlock()
+
+ return sendPacket(c, m)
+}
+
+func (c *conn) Close() error {
+ c.Lock()
+ defer c.Unlock()
+ return c.WriteCloser.Close()
+}
+
+type clientConn struct {
+ conn
+ wg sync.WaitGroup
+
+ sync.Mutex // protects inflight
+ inflight map[uint32]chan<- result // outstanding requests
+
+ closed chan struct{}
+ err error
+}
+
+// Wait blocks until the conn has shut down, and return the error
+// causing the shutdown. It can be called concurrently from multiple
+// goroutines.
+func (c *clientConn) Wait() error {
+ <-c.closed
+ return c.err
+}
+
+// Close closes the SFTP session.
+func (c *clientConn) Close() error {
+ defer c.wg.Wait()
+ return c.conn.Close()
+}
+
+func (c *clientConn) loop() {
+ defer c.wg.Done()
+ err := c.recv()
+ if err != nil {
+ c.broadcastErr(err)
+ }
+}
+
+// recv continuously reads from the server and forwards responses to the
+// appropriate channel.
+func (c *clientConn) recv() error {
+ defer c.conn.Close()
+
+ for {
+ typ, data, err := c.recvPacket(0)
+ if err != nil {
+ return err
+ }
+ sid, _, err := unmarshalUint32Safe(data)
+ if err != nil {
+ return err
+ }
+
+ ch, ok := c.getChannel(sid)
+ if !ok {
+ // This is an unexpected occurrence. Send the error
+ // back to all listeners so that they terminate
+ // gracefully.
+ return fmt.Errorf("sid not found: %d", sid)
+ }
+
+ ch <- result{typ: typ, data: data}
+ }
+}
+
+func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool {
+ c.Lock()
+ defer c.Unlock()
+
+ select {
+ case <-c.closed:
+ // already closed with broadcastErr, return error on chan.
+ ch <- result{err: ErrSSHFxConnectionLost}
+ return false
+ default:
+ }
+
+ c.inflight[sid] = ch
+ return true
+}
+
+func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
+ c.Lock()
+ defer c.Unlock()
+
+ ch, ok := c.inflight[sid]
+ delete(c.inflight, sid)
+
+ return ch, ok
+}
+
+// result captures the result of receiving the a packet from the server
+type result struct {
+ typ byte
+ data []byte
+ err error
+}
+
+type idmarshaler interface {
+ id() uint32
+ encoding.BinaryMarshaler
+}
+
+func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) {
+ if cap(ch) < 1 {
+ ch = make(chan result, 1)
+ }
+
+ c.dispatchRequest(ch, p)
+ s := <-ch
+ return s.typ, s.data, s.err
+}
+
+// dispatchRequest should ideally only be called by race-detection tests outside of this file,
+// where you have to ensure two packets are in flight sequentially after each other.
+func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
+ sid := p.id()
+
+ if !c.putChannel(ch, sid) {
+ // already closed.
+ return
+ }
+
+ if err := c.conn.sendPacket(p); err != nil {
+ if ch, ok := c.getChannel(sid); ok {
+ ch <- result{err: err}
+ }
+ }
+}
+
+// broadcastErr sends an error to all goroutines waiting for a response.
+func (c *clientConn) broadcastErr(err error) {
+ c.Lock()
+ defer c.Unlock()
+
+ bcastRes := result{err: ErrSSHFxConnectionLost}
+ for sid, ch := range c.inflight {
+ ch <- bcastRes
+
+ // Replace the chan in inflight,
+ // we have hijacked this chan,
+ // and this guarantees always-only-once sending.
+ c.inflight[sid] = make(chan<- result, 1)
+ }
+
+ c.err = err
+ close(c.closed)
+}
+
+type serverConn struct {
+ conn
+}
+
+func (s *serverConn) sendError(id uint32, err error) error {
+ return s.sendPacket(statusFromError(id, err))
+}
diff --git a/sftp/debug.go b/sftp/debug.go
new file mode 100644
index 0000000..3e264ab
--- /dev/null
+++ b/sftp/debug.go
@@ -0,0 +1,9 @@
+// +build debug
+
+package sftp
+
+import "log"
+
+func debug(fmt string, args ...interface{}) {
+ log.Printf(fmt, args...)
+}
diff --git a/sftp/ls_formatting.go b/sftp/ls_formatting.go
new file mode 100644
index 0000000..9c59070
--- /dev/null
+++ b/sftp/ls_formatting.go
@@ -0,0 +1,81 @@
+package sftp
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "os/user"
+ "strconv"
+ "time"
+
+ sshfx "git.deuxfleurs.fr/Deuxfleurs/bagage/internal/encoding/ssh/filexfer"
+)
+
+func lsFormatID(id uint32) string {
+ return strconv.FormatUint(uint64(id), 10)
+}
+
+type osIDLookup struct{}
+
+func (osIDLookup) Filelist(*Request) (ListerAt, error) {
+ return nil, errors.New("unimplemented stub")
+}
+
+func (osIDLookup) LookupUserName(uid string) string {
+ u, err := user.LookupId(uid)
+ if err != nil {
+ return uid
+ }
+
+ return u.Username
+}
+
+func (osIDLookup) LookupGroupName(gid string) string {
+ g, err := user.LookupGroupId(gid)
+ if err != nil {
+ return gid
+ }
+
+ return g.Name
+}
+
+// runLs formats the FileInfo as per `ls -l` style, which is in the 'longname' field of a SSH_FXP_NAME entry.
+// This is a fairly simple implementation, just enough to look close to openssh in simple cases.
+func runLs(idLookup NameLookupFileLister, dirent os.FileInfo) string {
+ // example from openssh sftp server:
+ // crw-rw-rw- 1 root wheel 0 Jul 31 20:52 ttyvd
+ // format:
+ // {directory / char device / etc}{rwxrwxrwx} {number of links} owner group size month day [time (this year) | year (otherwise)] name
+
+ symPerms := sshfx.FileMode(fromFileMode(dirent.Mode())).String()
+
+ var numLinks uint64 = 1
+ uid, gid := "0", "0"
+
+ switch sys := dirent.Sys().(type) {
+ case *sshfx.Attributes:
+ uid = lsFormatID(sys.UID)
+ gid = lsFormatID(sys.GID)
+ case *FileStat:
+ uid = lsFormatID(sys.UID)
+ gid = lsFormatID(sys.GID)
+ default:
+ numLinks, uid, gid = lsLinksUIDGID(dirent)
+ }
+
+ if idLookup != nil {
+ uid, gid = idLookup.LookupUserName(uid), idLookup.LookupGroupName(gid)
+ }
+
+ mtime := dirent.ModTime()
+ date := mtime.Format("Jan 2")
+
+ var yearOrTime string
+ if mtime.Before(time.Now().AddDate(0, -6, 0)) {
+ yearOrTime = mtime.Format("2006")
+ } else {
+ yearOrTime = mtime.Format("15:04")
+ }
+
+ return fmt.Sprintf("%s %4d %-8s %-8s %8d %s %5s %s", symPerms, numLinks, uid, gid, dirent.Size(), date, yearOrTime, dirent.Name())
+}
diff --git a/sftp/ls_plan9.go b/sftp/ls_plan9.go
new file mode 100644
index 0000000..a16a3ea
--- /dev/null
+++ b/sftp/ls_plan9.go
@@ -0,0 +1,21 @@
+// +build plan9
+
+package sftp
+
+import (
+ "os"
+ "syscall"
+)
+
+func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) {
+ numLinks = 1
+ uid, gid = "0", "0"
+
+ switch sys := fi.Sys().(type) {
+ case *syscall.Dir:
+ uid = sys.Uid
+ gid = sys.Gid
+ }
+
+ return numLinks, uid, gid
+}
diff --git a/sftp/ls_stub.go b/sftp/ls_stub.go
new file mode 100644
index 0000000..6dec393
--- /dev/null
+++ b/sftp/ls_stub.go
@@ -0,0 +1,11 @@
+// +build windows android
+
+package sftp
+
+import (
+ "os"
+)
+
+func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) {
+ return 1, "0", "0"
+}
diff --git a/sftp/ls_unix.go b/sftp/ls_unix.go
new file mode 100644
index 0000000..59ccffd
--- /dev/null
+++ b/sftp/ls_unix.go
@@ -0,0 +1,23 @@
+// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js
+
+package sftp
+
+import (
+ "os"
+ "syscall"
+)
+
+func lsLinksUIDGID(fi os.FileInfo) (numLinks uint64, uid, gid string) {
+ numLinks = 1
+ uid, gid = "0", "0"
+
+ switch sys := fi.Sys().(type) {
+ case *syscall.Stat_t:
+ numLinks = uint64(sys.Nlink)
+ uid = lsFormatID(sys.Uid)
+ gid = lsFormatID(sys.Gid)
+ default:
+ }
+
+ return numLinks, uid, gid
+}
diff --git a/sftp/packet.go b/sftp/packet.go
new file mode 100644
index 0000000..50ca069
--- /dev/null
+++ b/sftp/packet.go
@@ -0,0 +1,1276 @@
+package sftp
+
+import (
+ "bytes"
+ "encoding"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+)
+
+var (
+ errLongPacket = errors.New("packet too long")
+ errShortPacket = errors.New("packet too short")
+ errUnknownExtendedPacket = errors.New("unknown extended packet")
+)
+
+const (
+ maxMsgLength = 256 * 1024
+ debugDumpTxPacket = false
+ debugDumpRxPacket = false
+ debugDumpTxPacketBytes = false
+ debugDumpRxPacketBytes = false
+)
+
+func marshalUint32(b []byte, v uint32) []byte {
+ return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
+
+func marshalUint64(b []byte, v uint64) []byte {
+ return marshalUint32(marshalUint32(b, uint32(v>>32)), uint32(v))
+}
+
+func marshalString(b []byte, v string) []byte {
+ return append(marshalUint32(b, uint32(len(v))), v...)
+}
+
+func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
+ // attributes variable struct, and also variable per protocol version
+ // spec version 3 attributes:
+ // uint32 flags
+ // uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE
+ // uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID
+ // uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID
+ // uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS
+ // uint32 atime present only if flag SSH_FILEXFER_ACMODTIME
+ // uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME
+ // uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED
+ // string extended_type
+ // string extended_data
+ // ... more extended data (extended_type - extended_data pairs),
+ // so that number of pairs equals extended_count
+
+ flags, fileStat := fileStatFromInfo(fi)
+
+ b = marshalUint32(b, flags)
+ if flags&sshFileXferAttrSize != 0 {
+ b = marshalUint64(b, fileStat.Size)
+ }
+ if flags&sshFileXferAttrUIDGID != 0 {
+ b = marshalUint32(b, fileStat.UID)
+ b = marshalUint32(b, fileStat.GID)
+ }
+ if flags&sshFileXferAttrPermissions != 0 {
+ b = marshalUint32(b, fileStat.Mode)
+ }
+ if flags&sshFileXferAttrACmodTime != 0 {
+ b = marshalUint32(b, fileStat.Atime)
+ b = marshalUint32(b, fileStat.Mtime)
+ }
+
+ return b
+}
+
+func marshalStatus(b []byte, err StatusError) []byte {
+ b = marshalUint32(b, err.Code)
+ b = marshalString(b, err.msg)
+ b = marshalString(b, err.lang)
+ return b
+}
+
+func marshal(b []byte, v interface{}) []byte {
+ if v == nil {
+ return b
+ }
+ switch v := v.(type) {
+ case uint8:
+ return append(b, v)
+ case uint32:
+ return marshalUint32(b, v)
+ case uint64:
+ return marshalUint64(b, v)
+ case string:
+ return marshalString(b, v)
+ case os.FileInfo:
+ return marshalFileInfo(b, v)
+ default:
+ switch d := reflect.ValueOf(v); d.Kind() {
+ case reflect.Struct:
+ for i, n := 0, d.NumField(); i < n; i++ {
+ b = marshal(b, d.Field(i).Interface())
+ }
+ return b
+ case reflect.Slice:
+ for i, n := 0, d.Len(); i < n; i++ {
+ b = marshal(b, d.Index(i).Interface())
+ }
+ return b
+ default:
+ panic(fmt.Sprintf("marshal(%#v): cannot handle type %T", v, v))
+ }
+ }
+}
+
+func unmarshalUint32(b []byte) (uint32, []byte) {
+ v := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
+ return v, b[4:]
+}
+
+func unmarshalUint32Safe(b []byte) (uint32, []byte, error) {
+ var v uint32
+ if len(b) < 4 {
+ return 0, nil, errShortPacket
+ }
+ v, b = unmarshalUint32(b)
+ return v, b, nil
+}
+
+func unmarshalUint64(b []byte) (uint64, []byte) {
+ h, b := unmarshalUint32(b)
+ l, b := unmarshalUint32(b)
+ return uint64(h)<<32 | uint64(l), b
+}
+
+func unmarshalUint64Safe(b []byte) (uint64, []byte, error) {
+ var v uint64
+ if len(b) < 8 {
+ return 0, nil, errShortPacket
+ }
+ v, b = unmarshalUint64(b)
+ return v, b, nil
+}
+
+func unmarshalString(b []byte) (string, []byte) {
+ n, b := unmarshalUint32(b)
+ return string(b[:n]), b[n:]
+}
+
+func unmarshalStringSafe(b []byte) (string, []byte, error) {
+ n, b, err := unmarshalUint32Safe(b)
+ if err != nil {
+ return "", nil, err
+ }
+ if int64(n) > int64(len(b)) {
+ return "", nil, errShortPacket
+ }
+ return string(b[:n]), b[n:], nil
+}
+
+func unmarshalAttrs(b []byte) (*FileStat, []byte) {
+ flags, b := unmarshalUint32(b)
+ return unmarshalFileStat(flags, b)
+}
+
+func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
+ var fs FileStat
+ if flags&sshFileXferAttrSize == sshFileXferAttrSize {
+ fs.Size, b, _ = unmarshalUint64Safe(b)
+ }
+ if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
+ fs.UID, b, _ = unmarshalUint32Safe(b)
+ }
+ if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
+ fs.GID, b, _ = unmarshalUint32Safe(b)
+ }
+ if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
+ fs.Mode, b, _ = unmarshalUint32Safe(b)
+ }
+ if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
+ fs.Atime, b, _ = unmarshalUint32Safe(b)
+ fs.Mtime, b, _ = unmarshalUint32Safe(b)
+ }
+ if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
+ var count uint32
+ count, b, _ = unmarshalUint32Safe(b)
+ ext := make([]StatExtended, count)
+ for i := uint32(0); i < count; i++ {
+ var typ string
+ var data string
+ typ, b, _ = unmarshalStringSafe(b)
+ data, b, _ = unmarshalStringSafe(b)
+ ext[i] = StatExtended{
+ ExtType: typ,
+ ExtData: data,
+ }
+ }
+ fs.Extended = ext
+ }
+ return &fs, b
+}
+
+func unmarshalStatus(id uint32, data []byte) error {
+ sid, data := unmarshalUint32(data)
+ if sid != id {
+ return &unexpectedIDErr{id, sid}
+ }
+ code, data := unmarshalUint32(data)
+ msg, data, _ := unmarshalStringSafe(data)
+ lang, _, _ := unmarshalStringSafe(data)
+ return &StatusError{
+ Code: code,
+ msg: msg,
+ lang: lang,
+ }
+}
+
+type packetMarshaler interface {
+ marshalPacket() (header, payload []byte, err error)
+}
+
+func marshalPacket(m encoding.BinaryMarshaler) (header, payload []byte, err error) {
+ if m, ok := m.(packetMarshaler); ok {
+ return m.marshalPacket()
+ }
+
+ header, err = m.MarshalBinary()
+ return
+}
+
+// sendPacket marshals p according to RFC 4234.
+func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error {
+ header, payload, err := marshalPacket(m)
+ if err != nil {
+ return fmt.Errorf("binary marshaller failed: %w", err)
+ }
+
+ length := len(header) + len(payload) - 4 // subtract the uint32(length) from the start
+ if debugDumpTxPacketBytes {
+ debug("send packet: %s %d bytes %x%x", fxp(header[4]), length, header[5:], payload)
+ } else if debugDumpTxPacket {
+ debug("send packet: %s %d bytes", fxp(header[4]), length)
+ }
+
+ binary.BigEndian.PutUint32(header[:4], uint32(length))
+
+ if _, err := w.Write(header); err != nil {
+ return fmt.Errorf("failed to send packet: %w", err)
+ }
+
+ if len(payload) > 0 {
+ if _, err := w.Write(payload); err != nil {
+ return fmt.Errorf("failed to send packet payload: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) {
+ var b []byte
+ if alloc != nil {
+ b = alloc.GetPage(orderID)
+ } else {
+ b = make([]byte, 4)
+ }
+ if _, err := io.ReadFull(r, b[:4]); err != nil {
+ return 0, nil, err
+ }
+ length, _ := unmarshalUint32(b)
+ if length > maxMsgLength {
+ debug("recv packet %d bytes too long", length)
+ return 0, nil, errLongPacket
+ }
+ if length == 0 {
+ debug("recv packet of 0 bytes too short")
+ return 0, nil, errShortPacket
+ }
+ if alloc == nil {
+ b = make([]byte, length)
+ }
+ if _, err := io.ReadFull(r, b[:length]); err != nil {
+ debug("recv packet %d bytes: err %v", length, err)
+ return 0, nil, err
+ }
+ if debugDumpRxPacketBytes {
+ debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length])
+ } else if debugDumpRxPacket {
+ debug("recv packet: %s %d bytes", fxp(b[0]), length)
+ }
+ return b[0], b[1:length], nil
+}
+
+type extensionPair struct {
+ Name string
+ Data string
+}
+
+func unmarshalExtensionPair(b []byte) (extensionPair, []byte, error) {
+ var ep extensionPair
+ var err error
+ ep.Name, b, err = unmarshalStringSafe(b)
+ if err != nil {
+ return ep, b, err
+ }
+ ep.Data, b, err = unmarshalStringSafe(b)
+ return ep, b, err
+}
+
+// Here starts the definition of packets along with their MarshalBinary
+// implementations.
+// Manually writing the marshalling logic wins us a lot of time and
+// allocation.
+
+type sshFxInitPacket struct {
+ Version uint32
+ Extensions []extensionPair
+}
+
+func (p *sshFxInitPacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version)
+ for _, e := range p.Extensions {
+ l += 4 + len(e.Name) + 4 + len(e.Data)
+ }
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpInit)
+ b = marshalUint32(b, p.Version)
+
+ for _, e := range p.Extensions {
+ b = marshalString(b, e.Name)
+ b = marshalString(b, e.Data)
+ }
+
+ return b, nil
+}
+
+func (p *sshFxInitPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.Version, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ }
+ for len(b) > 0 {
+ var ep extensionPair
+ ep, b, err = unmarshalExtensionPair(b)
+ if err != nil {
+ return err
+ }
+ p.Extensions = append(p.Extensions, ep)
+ }
+ return nil
+}
+
+type sshFxVersionPacket struct {
+ Version uint32
+ Extensions []sshExtensionPair
+}
+
+type sshExtensionPair struct {
+ Name, Data string
+}
+
+func (p *sshFxVersionPacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version)
+ for _, e := range p.Extensions {
+ l += 4 + len(e.Name) + 4 + len(e.Data)
+ }
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpVersion)
+ b = marshalUint32(b, p.Version)
+
+ for _, e := range p.Extensions {
+ b = marshalString(b, e.Name)
+ b = marshalString(b, e.Data)
+ }
+
+ return b, nil
+}
+
+func marshalIDStringPacket(packetType byte, id uint32, str string) ([]byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(str)
+
+ b := make([]byte, 4, l)
+ b = append(b, packetType)
+ b = marshalUint32(b, id)
+ b = marshalString(b, str)
+
+ return b, nil
+}
+
+func unmarshalIDString(b []byte, id *uint32, str *string) error {
+ var err error
+ *id, b, err = unmarshalUint32Safe(b)
+ if err != nil {
+ return err
+ }
+ *str, _, err = unmarshalStringSafe(b)
+ return err
+}
+
+type sshFxpReaddirPacket struct {
+ ID uint32
+ Handle string
+}
+
+func (p *sshFxpReaddirPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpReaddirPacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpReaddir, p.ID, p.Handle)
+}
+
+func (p *sshFxpReaddirPacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Handle)
+}
+
+type sshFxpOpendirPacket struct {
+ ID uint32
+ Path string
+}
+
+func (p *sshFxpOpendirPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpOpendirPacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpOpendir, p.ID, p.Path)
+}
+
+func (p *sshFxpOpendirPacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Path)
+}
+
+type sshFxpLstatPacket struct {
+ ID uint32
+ Path string
+}
+
+func (p *sshFxpLstatPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpLstatPacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpLstat, p.ID, p.Path)
+}
+
+func (p *sshFxpLstatPacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Path)
+}
+
+type sshFxpStatPacket struct {
+ ID uint32
+ Path string
+}
+
+func (p *sshFxpStatPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpStatPacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpStat, p.ID, p.Path)
+}
+
+func (p *sshFxpStatPacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Path)
+}
+
+type sshFxpFstatPacket struct {
+ ID uint32
+ Handle string
+}
+
+func (p *sshFxpFstatPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpFstatPacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpFstat, p.ID, p.Handle)
+}
+
+func (p *sshFxpFstatPacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Handle)
+}
+
+type sshFxpClosePacket struct {
+ ID uint32
+ Handle string
+}
+
+func (p *sshFxpClosePacket) id() uint32 { return p.ID }
+
+func (p *sshFxpClosePacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpClose, p.ID, p.Handle)
+}
+
+func (p *sshFxpClosePacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Handle)
+}
+
+type sshFxpRemovePacket struct {
+ ID uint32
+ Filename string
+}
+
+func (p *sshFxpRemovePacket) id() uint32 { return p.ID }
+
+func (p *sshFxpRemovePacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpRemove, p.ID, p.Filename)
+}
+
+func (p *sshFxpRemovePacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Filename)
+}
+
+type sshFxpRmdirPacket struct {
+ ID uint32
+ Path string
+}
+
+func (p *sshFxpRmdirPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpRmdirPacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpRmdir, p.ID, p.Path)
+}
+
+func (p *sshFxpRmdirPacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Path)
+}
+
+type sshFxpSymlinkPacket struct {
+ ID uint32
+ Targetpath string
+ Linkpath string
+}
+
+func (p *sshFxpSymlinkPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpSymlinkPacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Targetpath) +
+ 4 + len(p.Linkpath)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpSymlink)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Targetpath)
+ b = marshalString(b, p.Linkpath)
+
+ return b, nil
+}
+
+func (p *sshFxpSymlinkPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Targetpath, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Linkpath, _, err = unmarshalStringSafe(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+type sshFxpHardlinkPacket struct {
+ ID uint32
+ Oldpath string
+ Newpath string
+}
+
+func (p *sshFxpHardlinkPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpHardlinkPacket) MarshalBinary() ([]byte, error) {
+ const ext = "hardlink@openssh.com"
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(ext) +
+ 4 + len(p.Oldpath) +
+ 4 + len(p.Newpath)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpExtended)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, ext)
+ b = marshalString(b, p.Oldpath)
+ b = marshalString(b, p.Newpath)
+
+ return b, nil
+}
+
+type sshFxpReadlinkPacket struct {
+ ID uint32
+ Path string
+}
+
+func (p *sshFxpReadlinkPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpReadlink, p.ID, p.Path)
+}
+
+func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Path)
+}
+
+type sshFxpRealpathPacket struct {
+ ID uint32
+ Path string
+}
+
+func (p *sshFxpRealpathPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpRealpathPacket) MarshalBinary() ([]byte, error) {
+ return marshalIDStringPacket(sshFxpRealpath, p.ID, p.Path)
+}
+
+func (p *sshFxpRealpathPacket) UnmarshalBinary(b []byte) error {
+ return unmarshalIDString(b, &p.ID, &p.Path)
+}
+
+type sshFxpNameAttr struct {
+ Name string
+ LongName string
+ Attrs []interface{}
+}
+
+func (p *sshFxpNameAttr) MarshalBinary() ([]byte, error) {
+ var b []byte
+ b = marshalString(b, p.Name)
+ b = marshalString(b, p.LongName)
+ for _, attr := range p.Attrs {
+ b = marshal(b, attr)
+ }
+ return b, nil
+}
+
+type sshFxpNamePacket struct {
+ ID uint32
+ NameAttrs []*sshFxpNameAttr
+}
+
+func (p *sshFxpNamePacket) marshalPacket() ([]byte, []byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpName)
+ b = marshalUint32(b, p.ID)
+ b = marshalUint32(b, uint32(len(p.NameAttrs)))
+
+ var payload []byte
+ for _, na := range p.NameAttrs {
+ ab, err := na.MarshalBinary()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ payload = append(payload, ab...)
+ }
+
+ return b, payload, nil
+}
+
+func (p *sshFxpNamePacket) MarshalBinary() ([]byte, error) {
+ header, payload, err := p.marshalPacket()
+ return append(header, payload...), err
+}
+
+type sshFxpOpenPacket struct {
+ ID uint32
+ Path string
+ Pflags uint32
+ Flags uint32 // ignored
+}
+
+func (p *sshFxpOpenPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Path) +
+ 4 + 4
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpOpen)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Path)
+ b = marshalUint32(b, p.Pflags)
+ b = marshalUint32(b, p.Flags)
+
+ return b, nil
+}
+
+func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Path, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+type sshFxpReadPacket struct {
+ ID uint32
+ Len uint32
+ Offset uint64
+ Handle string
+}
+
+func (p *sshFxpReadPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpReadPacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Handle) +
+ 8 + 4 // uint64 + uint32
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpRead)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Handle)
+ b = marshalUint64(b, p.Offset)
+ b = marshalUint32(b, p.Len)
+
+ return b, nil
+}
+
+func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil {
+ return err
+ } else if p.Len, _, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+// We need allocate bigger slices with extra capacity to avoid a re-allocation in sshFxpDataPacket.MarshalBinary
+// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length)
+const dataHeaderLen = 4 + 1 + 4 + 4
+
+func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
+ dataLen := p.Len
+ if dataLen > maxTxPacket {
+ dataLen = maxTxPacket
+ }
+
+ if alloc != nil {
+ // GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in
+ // sshFxpDataPacket.MarshalBinary
+ return alloc.GetPage(orderID)[:dataLen]
+ }
+
+ // allocate with extra space for the header
+ return make([]byte, dataLen, dataLen+dataHeaderLen)
+}
+
+type sshFxpRenamePacket struct {
+ ID uint32
+ Oldpath string
+ Newpath string
+}
+
+func (p *sshFxpRenamePacket) id() uint32 { return p.ID }
+
+func (p *sshFxpRenamePacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Oldpath) +
+ 4 + len(p.Newpath)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpRename)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Oldpath)
+ b = marshalString(b, p.Newpath)
+
+ return b, nil
+}
+
+func (p *sshFxpRenamePacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+type sshFxpPosixRenamePacket struct {
+ ID uint32
+ Oldpath string
+ Newpath string
+}
+
+func (p *sshFxpPosixRenamePacket) id() uint32 { return p.ID }
+
+func (p *sshFxpPosixRenamePacket) MarshalBinary() ([]byte, error) {
+ const ext = "posix-rename@openssh.com"
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(ext) +
+ 4 + len(p.Oldpath) +
+ 4 + len(p.Newpath)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpExtended)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, ext)
+ b = marshalString(b, p.Oldpath)
+ b = marshalString(b, p.Newpath)
+
+ return b, nil
+}
+
+type sshFxpWritePacket struct {
+ ID uint32
+ Length uint32
+ Offset uint64
+ Handle string
+ Data []byte
+}
+
+func (p *sshFxpWritePacket) id() uint32 { return p.ID }
+
+func (p *sshFxpWritePacket) marshalPacket() ([]byte, []byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Handle) +
+ 8 + // uint64
+ 4
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpWrite)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Handle)
+ b = marshalUint64(b, p.Offset)
+ b = marshalUint32(b, p.Length)
+
+ return b, p.Data, nil
+}
+
+func (p *sshFxpWritePacket) MarshalBinary() ([]byte, error) {
+ header, payload, err := p.marshalPacket()
+ return append(header, payload...), err
+}
+
+func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil {
+ return err
+ } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if uint32(len(b)) < p.Length {
+ return errShortPacket
+ }
+
+ p.Data = b[:p.Length]
+ return nil
+}
+
+type sshFxpMkdirPacket struct {
+ ID uint32
+ Flags uint32 // ignored
+ Path string
+}
+
+func (p *sshFxpMkdirPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Path) +
+ 4 // uint32
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpMkdir)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Path)
+ b = marshalUint32(b, p.Flags)
+
+ return b, nil
+}
+
+func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Path, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+type sshFxpSetstatPacket struct {
+ ID uint32
+ Flags uint32
+ Path string
+ Attrs interface{}
+}
+
+type sshFxpFsetstatPacket struct {
+ ID uint32
+ Flags uint32
+ Handle string
+ Attrs interface{}
+}
+
+func (p *sshFxpSetstatPacket) id() uint32 { return p.ID }
+func (p *sshFxpFsetstatPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Path) +
+ 4 // uint32
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpSetstat)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Path)
+ b = marshalUint32(b, p.Flags)
+
+ payload := marshal(nil, p.Attrs)
+
+ return b, payload, nil
+}
+
+func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) {
+ header, payload, err := p.marshalPacket()
+ return append(header, payload...), err
+}
+
+func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Handle) +
+ 4 // uint32
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpFsetstat)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Handle)
+ b = marshalUint32(b, p.Flags)
+
+ payload := marshal(nil, p.Attrs)
+
+ return b, payload, nil
+}
+
+func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) {
+ header, payload, err := p.marshalPacket()
+ return append(header, payload...), err
+}
+
+func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Path, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ }
+ p.Attrs = b
+ return nil
+}
+
+func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ }
+ p.Attrs = b
+ return nil
+}
+
+type sshFxpHandlePacket struct {
+ ID uint32
+ Handle string
+}
+
+func (p *sshFxpHandlePacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(p.Handle)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpHandle)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, p.Handle)
+
+ return b, nil
+}
+
+type sshFxpStatusPacket struct {
+ ID uint32
+ StatusError
+}
+
+func (p *sshFxpStatusPacket) MarshalBinary() ([]byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 +
+ 4 + len(p.StatusError.msg) +
+ 4 + len(p.StatusError.lang)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpStatus)
+ b = marshalUint32(b, p.ID)
+ b = marshalStatus(b, p.StatusError)
+
+ return b, nil
+}
+
+type sshFxpDataPacket struct {
+ ID uint32
+ Length uint32
+ Data []byte
+}
+
+func (p *sshFxpDataPacket) marshalPacket() ([]byte, []byte, error) {
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpData)
+ b = marshalUint32(b, p.ID)
+ b = marshalUint32(b, p.Length)
+
+ return b, p.Data, nil
+}
+
+// MarshalBinary encodes the receiver into a binary form and returns the result.
+// To avoid a new allocation the Data slice must have a capacity >= Length + 9
+//
+// This is hand-coded rather than just append(header, payload...),
+// in order to try and reuse the r.Data backing store in the packet.
+func (p *sshFxpDataPacket) MarshalBinary() ([]byte, error) {
+ b := append(p.Data, make([]byte, dataHeaderLen)...)
+ copy(b[dataHeaderLen:], p.Data[:p.Length])
+ // b[0:4] will be overwritten with the length in sendPacket
+ b[4] = sshFxpData
+ binary.BigEndian.PutUint32(b[5:9], p.ID)
+ binary.BigEndian.PutUint32(b[9:13], p.Length)
+ return b, nil
+}
+
+func (p *sshFxpDataPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if uint32(len(b)) < p.Length {
+ return errShortPacket
+ }
+
+ p.Data = b[:p.Length]
+ return nil
+}
+
+type sshFxpStatvfsPacket struct {
+ ID uint32
+ Path string
+}
+
+func (p *sshFxpStatvfsPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpStatvfsPacket) MarshalBinary() ([]byte, error) {
+ const ext = "statvfs@openssh.com"
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(ext) +
+ 4 + len(p.Path)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpExtended)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, ext)
+ b = marshalString(b, p.Path)
+
+ return b, nil
+}
+
+// A StatVFS contains statistics about a filesystem.
+type StatVFS struct {
+ ID uint32
+ Bsize uint64 /* file system block size */
+ Frsize uint64 /* fundamental fs block size */
+ Blocks uint64 /* number of blocks (unit f_frsize) */
+ Bfree uint64 /* free blocks in file system */
+ Bavail uint64 /* free blocks for non-root */
+ Files uint64 /* total file inodes */
+ Ffree uint64 /* free file inodes */
+ Favail uint64 /* free file inodes for to non-root */
+ Fsid uint64 /* file system id */
+ Flag uint64 /* bit mask of f_flag values */
+ Namemax uint64 /* maximum filename length */
+}
+
+// TotalSpace calculates the amount of total space in a filesystem.
+func (p *StatVFS) TotalSpace() uint64 {
+ return p.Frsize * p.Blocks
+}
+
+// FreeSpace calculates the amount of free space in a filesystem.
+func (p *StatVFS) FreeSpace() uint64 {
+ return p.Frsize * p.Bfree
+}
+
+// marshalPacket converts to ssh_FXP_EXTENDED_REPLY packet binary format
+func (p *StatVFS) marshalPacket() ([]byte, []byte, error) {
+ header := []byte{0, 0, 0, 0, sshFxpExtendedReply}
+
+ var buf bytes.Buffer
+ err := binary.Write(&buf, binary.BigEndian, p)
+
+ return header, buf.Bytes(), err
+}
+
+// MarshalBinary encodes the StatVFS as an SSH_FXP_EXTENDED_REPLY packet.
+func (p *StatVFS) MarshalBinary() ([]byte, error) {
+ header, payload, err := p.marshalPacket()
+ return append(header, payload...), err
+}
+
+type sshFxpFsyncPacket struct {
+ ID uint32
+ Handle string
+}
+
+func (p *sshFxpFsyncPacket) id() uint32 { return p.ID }
+
+func (p *sshFxpFsyncPacket) MarshalBinary() ([]byte, error) {
+ const ext = "fsync@openssh.com"
+ l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
+ 4 + len(ext) +
+ 4 + len(p.Handle)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpExtended)
+ b = marshalUint32(b, p.ID)
+ b = marshalString(b, ext)
+ b = marshalString(b, p.Handle)
+
+ return b, nil
+}
+
+type sshFxpExtendedPacket struct {
+ ID uint32
+ ExtendedRequest string
+ SpecificPacket interface {
+ serverRespondablePacket
+ readonly() bool
+ }
+}
+
+func (p *sshFxpExtendedPacket) id() uint32 { return p.ID }
+func (p *sshFxpExtendedPacket) readonly() bool {
+ if p.SpecificPacket == nil {
+ return true
+ }
+ return p.SpecificPacket.readonly()
+}
+
+func (p *sshFxpExtendedPacket) respond(svr *Server) responsePacket {
+ if p.SpecificPacket == nil {
+ return statusFromError(p.ID, nil)
+ }
+ return p.SpecificPacket.respond(svr)
+}
+
+func (p *sshFxpExtendedPacket) UnmarshalBinary(b []byte) error {
+ var err error
+ bOrig := b
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.ExtendedRequest, _, err = unmarshalStringSafe(b); err != nil {
+ return err
+ }
+
+ // specific unmarshalling
+ switch p.ExtendedRequest {
+ case "statvfs@openssh.com":
+ p.SpecificPacket = &sshFxpExtendedPacketStatVFS{}
+ case "posix-rename@openssh.com":
+ p.SpecificPacket = &sshFxpExtendedPacketPosixRename{}
+ case "hardlink@openssh.com":
+ p.SpecificPacket = &sshFxpExtendedPacketHardlink{}
+ default:
+ return fmt.Errorf("packet type %v: %w", p.SpecificPacket, errUnknownExtendedPacket)
+ }
+
+ return p.SpecificPacket.UnmarshalBinary(bOrig)
+}
+
+type sshFxpExtendedPacketStatVFS struct {
+ ID uint32
+ ExtendedRequest string
+ Path string
+}
+
+func (p *sshFxpExtendedPacketStatVFS) id() uint32 { return p.ID }
+func (p *sshFxpExtendedPacketStatVFS) readonly() bool { return true }
+func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Path, _, err = unmarshalStringSafe(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+type sshFxpExtendedPacketPosixRename struct {
+ ID uint32
+ ExtendedRequest string
+ Oldpath string
+ Newpath string
+}
+
+func (p *sshFxpExtendedPacketPosixRename) id() uint32 { return p.ID }
+func (p *sshFxpExtendedPacketPosixRename) readonly() bool { return false }
+func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket {
+ err := os.Rename(p.Oldpath, p.Newpath)
+ return statusFromError(p.ID, err)
+}
+
+type sshFxpExtendedPacketHardlink struct {
+ ID uint32
+ ExtendedRequest string
+ Oldpath string
+ Newpath string
+}
+
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
+func (p *sshFxpExtendedPacketHardlink) id() uint32 { return p.ID }
+func (p *sshFxpExtendedPacketHardlink) readonly() bool { return true }
+func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error {
+ var err error
+ if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
+ return err
+ } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil {
+ return err
+ } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket {
+ err := os.Link(p.Oldpath, p.Newpath)
+ return statusFromError(p.ID, err)
+}
diff --git a/sftp/packet_manager.go b/sftp/packet_manager.go
new file mode 100644
index 0000000..5aeb72b
--- /dev/null
+++ b/sftp/packet_manager.go
@@ -0,0 +1,221 @@
+package sftp
+
+/*
+ Imported from: https://github.com/pkg/sftp
+ */
+
+import (
+ "encoding"
+ "sort"
+ "sync"
+)
+
+// The goal of the packetManager is to keep the outgoing packets in the same
+// order as the incoming as is requires by section 7 of the RFC.
+
+type packetManager struct {
+ requests chan orderedPacket
+ responses chan orderedPacket
+ fini chan struct{}
+ incoming orderedPackets
+ outgoing orderedPackets
+ sender packetSender // connection object
+ working *sync.WaitGroup
+ packetCount uint32
+ // it is not nil if the allocator is enabled
+ alloc *allocator
+}
+
+type packetSender interface {
+ sendPacket(encoding.BinaryMarshaler) error
+}
+
+func newPktMgr(sender packetSender) *packetManager {
+ s := &packetManager{
+ requests: make(chan orderedPacket, SftpServerWorkerCount),
+ responses: make(chan orderedPacket, SftpServerWorkerCount),
+ fini: make(chan struct{}),
+ incoming: make([]orderedPacket, 0, SftpServerWorkerCount),
+ outgoing: make([]orderedPacket, 0, SftpServerWorkerCount),
+ sender: sender,
+ working: &sync.WaitGroup{},
+ }
+ go s.controller()
+ return s
+}
+
+//// packet ordering
+func (s *packetManager) newOrderID() uint32 {
+ s.packetCount++
+ return s.packetCount
+}
+
+// returns the next orderID without incrementing it.
+// This is used before receiving a new packet, with the allocator enabled, to associate
+// the slice allocated for the received packet with the orderID that will be used to mark
+// the allocated slices for reuse once the request is served
+func (s *packetManager) getNextOrderID() uint32 {
+ return s.packetCount + 1
+}
+
+type orderedRequest struct {
+ requestPacket
+ orderid uint32
+}
+
+func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
+ return orderedRequest{requestPacket: p, orderid: s.newOrderID()}
+}
+func (p orderedRequest) orderID() uint32 { return p.orderid }
+func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid }
+
+type orderedResponse struct {
+ responsePacket
+ orderid uint32
+}
+
+func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
+) orderedResponse {
+ return orderedResponse{responsePacket: p, orderid: id}
+}
+func (p orderedResponse) orderID() uint32 { return p.orderid }
+func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid }
+
+type orderedPacket interface {
+ id() uint32
+ orderID() uint32
+}
+type orderedPackets []orderedPacket
+
+func (o orderedPackets) Sort() {
+ sort.Slice(o, func(i, j int) bool {
+ return o[i].orderID() < o[j].orderID()
+ })
+}
+
+//// packet registry
+// register incoming packets to be handled
+func (s *packetManager) incomingPacket(pkt orderedRequest) {
+ s.working.Add(1)
+ s.requests <- pkt
+}
+
+// register outgoing packets as being ready
+func (s *packetManager) readyPacket(pkt orderedResponse) {
+ s.responses <- pkt
+ s.working.Done()
+}
+
+// shut down packetManager controller
+func (s *packetManager) close() {
+ // pause until current packets are processed
+ s.working.Wait()
+ close(s.fini)
+}
+
+// Passed a worker function, returns a channel for incoming packets.
+// Keep process packet responses in the order they are received while
+// maximizing throughput of file transfers.
+func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
+) chan orderedRequest {
+ // multiple workers for faster read/writes
+ rwChan := make(chan orderedRequest, SftpServerWorkerCount)
+ for i := 0; i < SftpServerWorkerCount; i++ {
+ runWorker(rwChan)
+ }
+
+ // single worker to enforce sequential processing of everything else
+ cmdChan := make(chan orderedRequest)
+ runWorker(cmdChan)
+
+ pktChan := make(chan orderedRequest, SftpServerWorkerCount)
+ go func() {
+ for pkt := range pktChan {
+ switch pkt.requestPacket.(type) {
+ case *sshFxpReadPacket, *sshFxpWritePacket:
+ s.incomingPacket(pkt)
+ rwChan <- pkt
+ continue
+ case *sshFxpClosePacket:
+ // wait for reads/writes to finish when file is closed
+ // incomingPacket() call must occur after this
+ s.working.Wait()
+ }
+ s.incomingPacket(pkt)
+ // all non-RW use sequential cmdChan
+ cmdChan <- pkt
+ }
+ close(rwChan)
+ close(cmdChan)
+ s.close()
+ }()
+
+ return pktChan
+}
+
+// process packets
+func (s *packetManager) controller() {
+ for {
+ select {
+ case pkt := <-s.requests:
+ debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID())
+ s.incoming = append(s.incoming, pkt)
+ s.incoming.Sort()
+ case pkt := <-s.responses:
+ debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID())
+ s.outgoing = append(s.outgoing, pkt)
+ s.outgoing.Sort()
+ case <-s.fini:
+ return
+ }
+ s.maybeSendPackets()
+ }
+}
+
+// send as many packets as are ready
+func (s *packetManager) maybeSendPackets() {
+ for {
+ if len(s.outgoing) == 0 || len(s.incoming) == 0 {
+ debug("break! -- outgoing: %v; incoming: %v",
+ len(s.outgoing), len(s.incoming))
+ break
+ }
+ out := s.outgoing[0]
+ in := s.incoming[0]
+ // debug("incoming: %v", ids(s.incoming))
+ // debug("outgoing: %v", ids(s.outgoing))
+ if in.orderID() == out.orderID() {
+ debug("Sending packet: %v", out.id())
+ s.sender.sendPacket(out.(encoding.BinaryMarshaler))
+ if s.alloc != nil {
+ // mark for reuse the slices allocated for this request
+ s.alloc.ReleasePages(in.orderID())
+ }
+ // pop off heads
+ copy(s.incoming, s.incoming[1:]) // shift left
+ s.incoming[len(s.incoming)-1] = nil // clear last
+ s.incoming = s.incoming[:len(s.incoming)-1] // remove last
+ copy(s.outgoing, s.outgoing[1:]) // shift left
+ s.outgoing[len(s.outgoing)-1] = nil // clear last
+ s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
+ } else {
+ break
+ }
+ }
+}
+
+// func oids(o []orderedPacket) []uint32 {
+// res := make([]uint32, 0, len(o))
+// for _, v := range o {
+// res = append(res, v.orderId())
+// }
+// return res
+// }
+// func ids(o []orderedPacket) []uint32 {
+// res := make([]uint32, 0, len(o))
+// for _, v := range o {
+// res = append(res, v.id())
+// }
+// return res
+// }
+
diff --git a/sftp/packet_typing.go b/sftp/packet_typing.go
new file mode 100644
index 0000000..f4f9052
--- /dev/null
+++ b/sftp/packet_typing.go
@@ -0,0 +1,135 @@
+package sftp
+
+import (
+ "encoding"
+ "fmt"
+)
+
+// all incoming packets
+type requestPacket interface {
+ encoding.BinaryUnmarshaler
+ id() uint32
+}
+
+type responsePacket interface {
+ encoding.BinaryMarshaler
+ id() uint32
+}
+
+// interfaces to group types
+type hasPath interface {
+ requestPacket
+ getPath() string
+}
+
+type hasHandle interface {
+ requestPacket
+ getHandle() string
+}
+
+type notReadOnly interface {
+ notReadOnly()
+}
+
+//// define types by adding methods
+// hasPath
+func (p *sshFxpLstatPacket) getPath() string { return p.Path }
+func (p *sshFxpStatPacket) getPath() string { return p.Path }
+func (p *sshFxpRmdirPacket) getPath() string { return p.Path }
+func (p *sshFxpReadlinkPacket) getPath() string { return p.Path }
+func (p *sshFxpRealpathPacket) getPath() string { return p.Path }
+func (p *sshFxpMkdirPacket) getPath() string { return p.Path }
+func (p *sshFxpSetstatPacket) getPath() string { return p.Path }
+func (p *sshFxpStatvfsPacket) getPath() string { return p.Path }
+func (p *sshFxpRemovePacket) getPath() string { return p.Filename }
+func (p *sshFxpRenamePacket) getPath() string { return p.Oldpath }
+func (p *sshFxpSymlinkPacket) getPath() string { return p.Targetpath }
+func (p *sshFxpOpendirPacket) getPath() string { return p.Path }
+func (p *sshFxpOpenPacket) getPath() string { return p.Path }
+
+func (p *sshFxpExtendedPacketPosixRename) getPath() string { return p.Oldpath }
+func (p *sshFxpExtendedPacketHardlink) getPath() string { return p.Oldpath }
+
+// getHandle
+func (p *sshFxpFstatPacket) getHandle() string { return p.Handle }
+func (p *sshFxpFsetstatPacket) getHandle() string { return p.Handle }
+func (p *sshFxpReadPacket) getHandle() string { return p.Handle }
+func (p *sshFxpWritePacket) getHandle() string { return p.Handle }
+func (p *sshFxpReaddirPacket) getHandle() string { return p.Handle }
+func (p *sshFxpClosePacket) getHandle() string { return p.Handle }
+
+// notReadOnly
+func (p *sshFxpWritePacket) notReadOnly() {}
+func (p *sshFxpSetstatPacket) notReadOnly() {}
+func (p *sshFxpFsetstatPacket) notReadOnly() {}
+func (p *sshFxpRemovePacket) notReadOnly() {}
+func (p *sshFxpMkdirPacket) notReadOnly() {}
+func (p *sshFxpRmdirPacket) notReadOnly() {}
+func (p *sshFxpRenamePacket) notReadOnly() {}
+func (p *sshFxpSymlinkPacket) notReadOnly() {}
+func (p *sshFxpExtendedPacketPosixRename) notReadOnly() {}
+func (p *sshFxpExtendedPacketHardlink) notReadOnly() {}
+
+// some packets with ID are missing id()
+func (p *sshFxpDataPacket) id() uint32 { return p.ID }
+func (p *sshFxpStatusPacket) id() uint32 { return p.ID }
+func (p *sshFxpStatResponse) id() uint32 { return p.ID }
+func (p *sshFxpNamePacket) id() uint32 { return p.ID }
+func (p *sshFxpHandlePacket) id() uint32 { return p.ID }
+func (p *StatVFS) id() uint32 { return p.ID }
+func (p *sshFxVersionPacket) id() uint32 { return 0 }
+
+// take raw incoming packet data and build packet objects
+func makePacket(p rxPacket) (requestPacket, error) {
+ var pkt requestPacket
+ switch p.pktType {
+ case sshFxpInit:
+ pkt = &sshFxInitPacket{}
+ case sshFxpLstat:
+ pkt = &sshFxpLstatPacket{}
+ case sshFxpOpen:
+ pkt = &sshFxpOpenPacket{}
+ case sshFxpClose:
+ pkt = &sshFxpClosePacket{}
+ case sshFxpRead:
+ pkt = &sshFxpReadPacket{}
+ case sshFxpWrite:
+ pkt = &sshFxpWritePacket{}
+ case sshFxpFstat:
+ pkt = &sshFxpFstatPacket{}
+ case sshFxpSetstat:
+ pkt = &sshFxpSetstatPacket{}
+ case sshFxpFsetstat:
+ pkt = &sshFxpFsetstatPacket{}
+ case sshFxpOpendir:
+ pkt = &sshFxpOpendirPacket{}
+ case sshFxpReaddir:
+ pkt = &sshFxpReaddirPacket{}
+ case sshFxpRemove:
+ pkt = &sshFxpRemovePacket{}
+ case sshFxpMkdir:
+ pkt = &sshFxpMkdirPacket{}
+ case sshFxpRmdir:
+ pkt = &sshFxpRmdirPacket{}
+ case sshFxpRealpath:
+ pkt = &sshFxpRealpathPacket{}
+ case sshFxpStat:
+ pkt = &sshFxpStatPacket{}
+ case sshFxpRename:
+ pkt = &sshFxpRenamePacket{}
+ case sshFxpReadlink:
+ pkt = &sshFxpReadlinkPacket{}
+ case sshFxpSymlink:
+ pkt = &sshFxpSymlinkPacket{}
+ case sshFxpExtended:
+ pkt = &sshFxpExtendedPacket{}
+ default:
+ return nil, fmt.Errorf("unhandled packet type: %s", p.pktType)
+ }
+ if err := pkt.UnmarshalBinary(p.pktBytes); err != nil {
+ // Return partially unpacked packet to allow callers to return
+ // error messages appropriately with necessary id() method.
+ return pkt, err
+ }
+ return pkt, nil
+}
diff --git a/sftp/pool.go b/sftp/pool.go
new file mode 100644
index 0000000..3612629
--- /dev/null
+++ b/sftp/pool.go
@@ -0,0 +1,79 @@
+package sftp
+
+// bufPool provides a pool of byte-slices to be reused in various parts of the package.
+// It is safe to use concurrently through a pointer.
+type bufPool struct {
+ ch chan []byte
+ blen int
+}
+
+func newBufPool(depth, bufLen int) *bufPool {
+ return &bufPool{
+ ch: make(chan []byte, depth),
+ blen: bufLen,
+ }
+}
+
+func (p *bufPool) Get() []byte {
+ if p.blen <= 0 {
+ panic("bufPool: new buffer creation length must be greater than zero")
+ }
+
+ for {
+ select {
+ case b := <-p.ch:
+ if cap(b) < p.blen {
+ // just in case: throw away any buffer with insufficient capacity.
+ continue
+ }
+
+ return b[:p.blen]
+
+ default:
+ return make([]byte, p.blen)
+ }
+ }
+}
+
+func (p *bufPool) Put(b []byte) {
+ if p == nil {
+ // functional default: no reuse.
+ return
+ }
+
+ if cap(b) < p.blen || cap(b) > p.blen*2 {
+ // DO NOT reuse buffers with insufficient capacity.
+ // This could cause panics when resizing to p.blen.
+
+ // DO NOT reuse buffers with excessive capacity.
+ // This could cause memory leaks.
+ return
+ }
+
+ select {
+ case p.ch <- b:
+ default:
+ }
+}
+
+type resChanPool chan chan result
+
+func newResChanPool(depth int) resChanPool {
+ return make(chan chan result, depth)
+}
+
+func (p resChanPool) Get() chan result {
+ select {
+ case ch := <-p:
+ return ch
+ default:
+ return make(chan result, 1)
+ }
+}
+
+func (p resChanPool) Put(ch chan result) {
+ select {
+ case p <- ch:
+ default:
+ }
+}
diff --git a/sftp/release.go b/sftp/release.go
new file mode 100644
index 0000000..b695528
--- /dev/null
+++ b/sftp/release.go
@@ -0,0 +1,5 @@
+// +build !debug
+
+package sftp
+
+func debug(fmt string, args ...interface{}) {}
diff --git a/sftp/request-attrs.go b/sftp/request-attrs.go
new file mode 100644
index 0000000..b5c95b4
--- /dev/null
+++ b/sftp/request-attrs.go
@@ -0,0 +1,63 @@
+package sftp
+
+// Methods on the Request object to make working with the Flags bitmasks and
+// Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write
+// request and AttrFlags() and Attributes() when working with SetStat requests.
+import "os"
+
+// FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags
+// (https://golang.org/pkg/os/#pkg-constants).
+type FileOpenFlags struct {
+ Read, Write, Append, Creat, Trunc, Excl bool
+}
+
+func newFileOpenFlags(flags uint32) FileOpenFlags {
+ return FileOpenFlags{
+ Read: flags&sshFxfRead != 0,
+ Write: flags&sshFxfWrite != 0,
+ Append: flags&sshFxfAppend != 0,
+ Creat: flags&sshFxfCreat != 0,
+ Trunc: flags&sshFxfTrunc != 0,
+ Excl: flags&sshFxfExcl != 0,
+ }
+}
+
+// Pflags converts the bitmap/uint32 from SFTP Open packet pflag values,
+// into a FileOpenFlags struct with booleans set for flags set in bitmap.
+func (r *Request) Pflags() FileOpenFlags {
+ return newFileOpenFlags(r.Flags)
+}
+
+// FileAttrFlags that indicate whether SFTP file attributes were passed. When a flag is
+// true the corresponding attribute should be available from the FileStat
+// object returned by Attributes method. Used with SetStat.
+type FileAttrFlags struct {
+ Size, UidGid, Permissions, Acmodtime bool
+}
+
+func newFileAttrFlags(flags uint32) FileAttrFlags {
+ return FileAttrFlags{
+ Size: (flags & sshFileXferAttrSize) != 0,
+ UidGid: (flags & sshFileXferAttrUIDGID) != 0,
+ Permissions: (flags & sshFileXferAttrPermissions) != 0,
+ Acmodtime: (flags & sshFileXferAttrACmodTime) != 0,
+ }
+}
+
+// AttrFlags returns a FileAttrFlags boolean struct based on the
+// bitmap/uint32 file attribute flags from the SFTP packaet.
+func (r *Request) AttrFlags() FileAttrFlags {
+ return newFileAttrFlags(r.Flags)
+}
+
+// FileMode returns the Mode SFTP file attributes wrapped as os.FileMode
+func (a FileStat) FileMode() os.FileMode {
+ return os.FileMode(a.Mode)
+}
+
+// Attributes parses file attributes byte blob and return them in a
+// FileStat object.
+func (r *Request) Attributes() *FileStat {
+ fs, _ := unmarshalFileStat(r.Flags, r.Attrs)
+ return fs
+}
diff --git a/sftp/request-errors.go b/sftp/request-errors.go
new file mode 100644
index 0000000..6505b5c
--- /dev/null
+++ b/sftp/request-errors.go
@@ -0,0 +1,54 @@
+package sftp
+
+type fxerr uint32
+
+// Error types that match the SFTP's SSH_FXP_STATUS codes. Gives you more
+// direct control of the errors being sent vs. letting the library work them
+// out from the standard os/io errors.
+const (
+ ErrSSHFxOk = fxerr(sshFxOk)
+ ErrSSHFxEOF = fxerr(sshFxEOF)
+ ErrSSHFxNoSuchFile = fxerr(sshFxNoSuchFile)
+ ErrSSHFxPermissionDenied = fxerr(sshFxPermissionDenied)
+ ErrSSHFxFailure = fxerr(sshFxFailure)
+ ErrSSHFxBadMessage = fxerr(sshFxBadMessage)
+ ErrSSHFxNoConnection = fxerr(sshFxNoConnection)
+ ErrSSHFxConnectionLost = fxerr(sshFxConnectionLost)
+ ErrSSHFxOpUnsupported = fxerr(sshFxOPUnsupported)
+)
+
+// Deprecated error types, these are aliases for the new ones, please use the new ones directly
+const (
+ ErrSshFxOk = ErrSSHFxOk
+ ErrSshFxEof = ErrSSHFxEOF
+ ErrSshFxNoSuchFile = ErrSSHFxNoSuchFile
+ ErrSshFxPermissionDenied = ErrSSHFxPermissionDenied
+ ErrSshFxFailure = ErrSSHFxFailure
+ ErrSshFxBadMessage = ErrSSHFxBadMessage
+ ErrSshFxNoConnection = ErrSSHFxNoConnection
+ ErrSshFxConnectionLost = ErrSSHFxConnectionLost
+ ErrSshFxOpUnsupported = ErrSSHFxOpUnsupported
+)
+
+func (e fxerr) Error() string {
+ switch e {
+ case ErrSSHFxOk:
+ return "OK"
+ case ErrSSHFxEOF:
+ return "EOF"
+ case ErrSSHFxNoSuchFile:
+ return "no such file"
+ case ErrSSHFxPermissionDenied:
+ return "permission denied"
+ case ErrSSHFxBadMessage:
+ return "bad message"
+ case ErrSSHFxNoConnection:
+ return "no connection"
+ case ErrSSHFxConnectionLost:
+ return "connection lost"
+ case ErrSSHFxOpUnsupported:
+ return "operation unsupported"
+ default:
+ return "failure"
+ }
+}
diff --git a/sftp/request-interfaces.go b/sftp/request-interfaces.go
new file mode 100644
index 0000000..c8c424c
--- /dev/null
+++ b/sftp/request-interfaces.go
@@ -0,0 +1,121 @@
+package sftp
+
+import (
+ "io"
+ "os"
+)
+
+// WriterAtReaderAt defines the interface to return when a file is to
+// be opened for reading and writing
+type WriterAtReaderAt interface {
+ io.WriterAt
+ io.ReaderAt
+}
+
+// Interfaces are differentiated based on required returned values.
+// All input arguments are to be pulled from Request (the only arg).
+
+// The Handler interfaces all take the Request object as its only argument.
+// All the data you should need to handle the call are in the Request object.
+// The request.Method attribute is initially the most important one as it
+// determines which Handler gets called.
+
+// FileReader should return an io.ReaderAt for the filepath
+// Note in cases of an error, the error text will be sent to the client.
+// Called for Methods: Get
+type FileReader interface {
+ Fileread(*Request) (io.ReaderAt, error)
+}
+
+// FileWriter should return an io.WriterAt for the filepath.
+//
+// The request server code will call Close() on the returned io.WriterAt
+// ojbect if an io.Closer type assertion succeeds.
+// Note in cases of an error, the error text will be sent to the client.
+// Note when receiving an Append flag it is important to not open files using
+// O_APPEND if you plan to use WriteAt, as they conflict.
+// Called for Methods: Put, Open
+type FileWriter interface {
+ Filewrite(*Request) (io.WriterAt, error)
+}
+
+// OpenFileWriter is a FileWriter that implements the generic OpenFile method.
+// You need to implement this optional interface if you want to be able
+// to read and write from/to the same handle.
+// Called for Methods: Open
+type OpenFileWriter interface {
+ FileWriter
+ OpenFile(*Request) (WriterAtReaderAt, error)
+}
+
+// FileCmder should return an error
+// Note in cases of an error, the error text will be sent to the client.
+// Called for Methods: Setstat, Rename, Rmdir, Mkdir, Link, Symlink, Remove
+type FileCmder interface {
+ Filecmd(*Request) error
+}
+
+// PosixRenameFileCmder is a FileCmder that implements the PosixRename method.
+// If this interface is implemented PosixRename requests will call it
+// otherwise they will be handled in the same way as Rename
+type PosixRenameFileCmder interface {
+ FileCmder
+ PosixRename(*Request) error
+}
+
+// StatVFSFileCmder is a FileCmder that implements the StatVFS method.
+// You need to implement this interface if you want to handle statvfs requests.
+// Please also be sure that the statvfs@openssh.com extension is enabled
+type StatVFSFileCmder interface {
+ FileCmder
+ StatVFS(*Request) (*StatVFS, error)
+}
+
+// FileLister should return an object that fulfils the ListerAt interface
+// Note in cases of an error, the error text will be sent to the client.
+// Called for Methods: List, Stat, Readlink
+type FileLister interface {
+ Filelist(*Request) (ListerAt, error)
+}
+
+// LstatFileLister is a FileLister that implements the Lstat method.
+// If this interface is implemented Lstat requests will call it
+// otherwise they will be handled in the same way as Stat
+type LstatFileLister interface {
+ FileLister
+ Lstat(*Request) (ListerAt, error)
+}
+
+// RealPathFileLister is a FileLister that implements the Realpath method.
+// We use "/" as start directory for relative paths, implementing this
+// interface you can customize the start directory.
+// You have to return an absolute POSIX path.
+type RealPathFileLister interface {
+ FileLister
+ RealPath(string) string
+}
+
+// NameLookupFileLister is a FileLister that implmeents the LookupUsername and LookupGroupName methods.
+// If this interface is implemented, then longname ls formatting will use these to convert usernames and groupnames.
+type NameLookupFileLister interface {
+ FileLister
+ LookupUserName(string) string
+ LookupGroupName(string) string
+}
+
+// ListerAt does for file lists what io.ReaderAt does for files.
+// ListAt should return the number of entries copied and an io.EOF
+// error if at end of list. This is testable by comparing how many you
+// copied to how many could be copied (eg. n < len(ls) below).
+// The copy() builtin is best for the copying.
+// Note in cases of an error, the error text will be sent to the client.
+type ListerAt interface {
+ ListAt([]os.FileInfo, int64) (int, error)
+}
+
+// TransferError is an optional interface that readerAt and writerAt
+// can implement to be notified about the error causing Serve() to exit
+// with the request still open
+type TransferError interface {
+ TransferError(err error)
+}
diff --git a/sftp/request-plan9.go b/sftp/request-plan9.go
new file mode 100644
index 0000000..2444da5
--- /dev/null
+++ b/sftp/request-plan9.go
@@ -0,0 +1,34 @@
+// +build plan9
+
+package sftp
+
+import (
+ "path"
+ "path/filepath"
+ "syscall"
+)
+
+func fakeFileInfoSys() interface{} {
+ return &syscall.Dir{}
+}
+
+func testOsSys(sys interface{}) error {
+ return nil
+}
+
+func toLocalPath(p string) string {
+ lp := filepath.FromSlash(p)
+
+ if path.IsAbs(p) {
+ tmp := lp[1:]
+
+ if filepath.IsAbs(tmp) {
+ // If the FromSlash without any starting slashes is absolute,
+ // then we have a filepath encoded with a prefix '/'.
+ // e.g. "/#s/boot" to "#s/boot"
+ return tmp
+ }
+ }
+
+ return lp
+}
diff --git a/sftp/request-server.go b/sftp/request-server.go
new file mode 100644
index 0000000..5fa828b
--- /dev/null
+++ b/sftp/request-server.go
@@ -0,0 +1,304 @@
+package sftp
+
+import (
+ "context"
+ "errors"
+ "io"
+ "path"
+ "path/filepath"
+ "strconv"
+ "sync"
+)
+
+var maxTxPacket uint32 = 1 << 15
+
+// Handlers contains the 4 SFTP server request handlers.
+type Handlers struct {
+ FileGet FileReader
+ FilePut FileWriter
+ FileCmd FileCmder
+ FileList FileLister
+}
+
+// RequestServer abstracts the sftp protocol with an http request-like protocol
+type RequestServer struct {
+ Handlers Handlers
+
+ *serverConn
+ pktMgr *packetManager
+
+ mu sync.RWMutex
+ handleCount int
+ openRequests map[string]*Request
+}
+
+// A RequestServerOption is a function which applies configuration to a RequestServer.
+type RequestServerOption func(*RequestServer)
+
+// WithRSAllocator enable the allocator.
+// After processing a packet we keep in memory the allocated slices
+// and we reuse them for new packets.
+// The allocator is experimental
+func WithRSAllocator() RequestServerOption {
+ return func(rs *RequestServer) {
+ alloc := newAllocator()
+ rs.pktMgr.alloc = alloc
+ rs.conn.alloc = alloc
+ }
+}
+
+// NewRequestServer creates/allocates/returns new RequestServer.
+// Normally there will be one server per user-session.
+func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
+ svrConn := &serverConn{
+ conn: conn{
+ Reader: rwc,
+ WriteCloser: rwc,
+ },
+ }
+ rs := &RequestServer{
+ Handlers: h,
+
+ serverConn: svrConn,
+ pktMgr: newPktMgr(svrConn),
+
+ openRequests: make(map[string]*Request),
+ }
+
+ for _, o := range options {
+ o(rs)
+ }
+ return rs
+}
+
+// New Open packet/Request
+func (rs *RequestServer) nextRequest(r *Request) string {
+ rs.mu.Lock()
+ defer rs.mu.Unlock()
+
+ rs.handleCount++
+
+ r.handle = strconv.Itoa(rs.handleCount)
+ rs.openRequests[r.handle] = r
+
+ return r.handle
+}
+
+// Returns Request from openRequests, bool is false if it is missing.
+//
+// The Requests in openRequests work essentially as open file descriptors that
+// you can do different things with. What you are doing with it are denoted by
+// the first packet of that type (read/write/etc).
+func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
+ rs.mu.RLock()
+ defer rs.mu.RUnlock()
+
+ r, ok := rs.openRequests[handle]
+ return r, ok
+}
+
+// Close the Request and clear from openRequests map
+func (rs *RequestServer) closeRequest(handle string) error {
+ rs.mu.Lock()
+ defer rs.mu.Unlock()
+
+ if r, ok := rs.openRequests[handle]; ok {
+ delete(rs.openRequests, handle)
+ return r.close()
+ }
+
+ return EBADF
+}
+
+// Close the read/write/closer to trigger exiting the main server loop
+func (rs *RequestServer) Close() error { return rs.conn.Close() }
+
+func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
+ defer close(pktChan) // shuts down sftpServerWorkers
+
+ var err error
+ var pkt requestPacket
+ var pktType uint8
+ var pktBytes []byte
+
+ for {
+ pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID())
+ if err != nil {
+ // we don't care about releasing allocated pages here, the server will quit and the allocator freed
+ return err
+ }
+
+ pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
+ if err != nil {
+ switch {
+ case errors.Is(err, errUnknownExtendedPacket):
+ // do nothing
+ default:
+ debug("makePacket err: %v", err)
+ rs.conn.Close() // shuts down recvPacket
+ return err
+ }
+ }
+
+ pktChan <- rs.pktMgr.newOrderedRequest(pkt)
+ }
+}
+
+// Serve requests for user session
+func (rs *RequestServer) Serve() error {
+ defer func() {
+ if rs.pktMgr.alloc != nil {
+ rs.pktMgr.alloc.Free()
+ }
+ }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ var wg sync.WaitGroup
+ runWorker := func(ch chan orderedRequest) {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := rs.packetWorker(ctx, ch); err != nil {
+ rs.conn.Close() // shuts down recvPacket
+ }
+ }()
+ }
+ pktChan := rs.pktMgr.workerChan(runWorker)
+
+ err := rs.serveLoop(pktChan)
+
+ wg.Wait() // wait for all workers to exit
+
+ rs.mu.Lock()
+ defer rs.mu.Unlock()
+
+ // make sure all open requests are properly closed
+ // (eg. possible on dropped connections, client crashes, etc.)
+ for handle, req := range rs.openRequests {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ req.transferError(err)
+
+ delete(rs.openRequests, handle)
+ req.close()
+ }
+
+ return err
+}
+
+func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedRequest) error {
+ for pkt := range pktChan {
+ orderID := pkt.orderID()
+ if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok {
+ if epkt.SpecificPacket != nil {
+ pkt.requestPacket = epkt.SpecificPacket
+ }
+ }
+
+ var rpkt responsePacket
+ switch pkt := pkt.requestPacket.(type) {
+ case *sshFxInitPacket:
+ rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions}
+ case *sshFxpClosePacket:
+ handle := pkt.getHandle()
+ rpkt = statusFromError(pkt.ID, rs.closeRequest(handle))
+ case *sshFxpRealpathPacket:
+ var realPath string
+ if realPather, ok := rs.Handlers.FileList.(RealPathFileLister); ok {
+ realPath = realPather.RealPath(pkt.getPath())
+ } else {
+ realPath = cleanPath(pkt.getPath())
+ }
+ rpkt = cleanPacketPath(pkt, realPath)
+ case *sshFxpOpendirPacket:
+ request := requestFromPacket(ctx, pkt)
+ handle := rs.nextRequest(request)
+ rpkt = request.opendir(rs.Handlers, pkt)
+ if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
+ // if we return an error we have to remove the handle from the active ones
+ rs.closeRequest(handle)
+ }
+ case *sshFxpOpenPacket:
+ request := requestFromPacket(ctx, pkt)
+ handle := rs.nextRequest(request)
+ rpkt = request.open(rs.Handlers, pkt)
+ if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
+ // if we return an error we have to remove the handle from the active ones
+ rs.closeRequest(handle)
+ }
+ case *sshFxpFstatPacket:
+ handle := pkt.getHandle()
+ request, ok := rs.getRequest(handle)
+ if !ok {
+ rpkt = statusFromError(pkt.ID, EBADF)
+ } else {
+ request = NewRequest("Stat", request.Filepath)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ }
+ case *sshFxpFsetstatPacket:
+ handle := pkt.getHandle()
+ request, ok := rs.getRequest(handle)
+ if !ok {
+ rpkt = statusFromError(pkt.ID, EBADF)
+ } else {
+ request = NewRequest("Setstat", request.Filepath)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ }
+ case *sshFxpExtendedPacketPosixRename:
+ request := NewRequest("PosixRename", pkt.Oldpath)
+ request.Target = pkt.Newpath
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ case *sshFxpExtendedPacketStatVFS:
+ request := NewRequest("StatVFS", pkt.Path)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ case hasHandle:
+ handle := pkt.getHandle()
+ request, ok := rs.getRequest(handle)
+ if !ok {
+ rpkt = statusFromError(pkt.id(), EBADF)
+ } else {
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ }
+ case hasPath:
+ request := requestFromPacket(ctx, pkt)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ request.close()
+ default:
+ rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
+ }
+
+ rs.pktMgr.readyPacket(
+ rs.pktMgr.newOrderedResponse(rpkt, orderID))
+ }
+ return nil
+}
+
+// clean and return name packet for file
+func cleanPacketPath(pkt *sshFxpRealpathPacket, realPath string) responsePacket {
+ return &sshFxpNamePacket{
+ ID: pkt.id(),
+ NameAttrs: []*sshFxpNameAttr{
+ {
+ Name: realPath,
+ LongName: realPath,
+ Attrs: emptyFileStat,
+ },
+ },
+ }
+}
+
+// Makes sure we have a clean POSIX (/) absolute path to work with
+func cleanPath(p string) string {
+ return cleanPathWithBase("/", p)
+}
+
+func cleanPathWithBase(base, p string) string {
+ p = filepath.ToSlash(filepath.Clean(p))
+ if !path.IsAbs(p) {
+ return path.Join(base, p)
+ }
+ return p
+}
diff --git a/sftp/request-unix.go b/sftp/request-unix.go
new file mode 100644
index 0000000..50b08a3
--- /dev/null
+++ b/sftp/request-unix.go
@@ -0,0 +1,27 @@
+// +build !windows,!plan9
+
+package sftp
+
+import (
+ "errors"
+ "syscall"
+)
+
+func fakeFileInfoSys() interface{} {
+ return &syscall.Stat_t{Uid: 65534, Gid: 65534}
+}
+
+func testOsSys(sys interface{}) error {
+ fstat := sys.(*FileStat)
+ if fstat.UID != uint32(65534) {
+ return errors.New("Uid failed to match")
+ }
+ if fstat.GID != uint32(65534) {
+ return errors.New("Gid failed to match")
+ }
+ return nil
+}
+
+func toLocalPath(p string) string {
+ return p
+}
diff --git a/sftp/request.go b/sftp/request.go
new file mode 100644
index 0000000..c6da4b6
--- /dev/null
+++ b/sftp/request.go
@@ -0,0 +1,628 @@
+package sftp
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "strings"
+ "sync"
+ "syscall"
+)
+
+// MaxFilelist is the max number of files to return in a readdir batch.
+var MaxFilelist int64 = 100
+
+// state encapsulates the reader/writer/readdir from handlers.
+type state struct {
+ mu sync.RWMutex
+
+ writerAt io.WriterAt
+ readerAt io.ReaderAt
+ writerAtReaderAt WriterAtReaderAt
+ listerAt ListerAt
+ lsoffset int64
+}
+
+// copy returns a shallow copy the state.
+// This is broken out to specific fields,
+// because we have to copy around the mutex in state.
+func (s *state) copy() state {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return state{
+ writerAt: s.writerAt,
+ readerAt: s.readerAt,
+ writerAtReaderAt: s.writerAtReaderAt,
+ listerAt: s.listerAt,
+ lsoffset: s.lsoffset,
+ }
+}
+
+func (s *state) setReaderAt(rd io.ReaderAt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.readerAt = rd
+}
+
+func (s *state) getReaderAt() io.ReaderAt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.readerAt
+}
+
+func (s *state) setWriterAt(rd io.WriterAt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.writerAt = rd
+}
+
+func (s *state) getWriterAt() io.WriterAt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.writerAt
+}
+
+func (s *state) setWriterAtReaderAt(rw WriterAtReaderAt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.writerAtReaderAt = rw
+}
+
+func (s *state) getWriterAtReaderAt() WriterAtReaderAt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.writerAtReaderAt
+}
+
+func (s *state) getAllReaderWriters() (io.ReaderAt, io.WriterAt, WriterAtReaderAt) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.readerAt, s.writerAt, s.writerAtReaderAt
+}
+
+// Returns current offset for file list
+func (s *state) lsNext() int64 {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.lsoffset
+}
+
+// Increases next offset
+func (s *state) lsInc(offset int64) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.lsoffset += offset
+}
+
+// manage file read/write state
+func (s *state) setListerAt(la ListerAt) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.listerAt = la
+}
+
+func (s *state) getListerAt() ListerAt {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ return s.listerAt
+}
+
+// Request contains the data and state for the incoming service request.
+type Request struct {
+ // Get, Put, Setstat, Stat, Rename, Remove
+ // Rmdir, Mkdir, List, Readlink, Link, Symlink
+ Method string
+ Filepath string
+ Flags uint32
+ Attrs []byte // convert to sub-struct
+ Target string // for renames and sym-links
+ handle string
+
+ // reader/writer/readdir from handlers
+ state
+
+ // context lasts duration of request
+ ctx context.Context
+ cancelCtx context.CancelFunc
+}
+
+// NewRequest creates a new Request object.
+func NewRequest(method, path string) *Request {
+ return &Request{
+ Method: method,
+ Filepath: cleanPath(path),
+ }
+}
+
+// copy returns a shallow copy of existing request.
+// This is broken out to specific fields,
+// because we have to copy around the mutex in state.
+func (r *Request) copy() *Request {
+ return &Request{
+ Method: r.Method,
+ Filepath: r.Filepath,
+ Flags: r.Flags,
+ Attrs: r.Attrs,
+ Target: r.Target,
+ handle: r.handle,
+
+ state: r.state.copy(),
+
+ ctx: r.ctx,
+ cancelCtx: r.cancelCtx,
+ }
+}
+
+// New Request initialized based on packet data
+func requestFromPacket(ctx context.Context, pkt hasPath) *Request {
+ method := requestMethod(pkt)
+ request := NewRequest(method, pkt.getPath())
+ request.ctx, request.cancelCtx = context.WithCancel(ctx)
+
+ switch p := pkt.(type) {
+ case *sshFxpOpenPacket:
+ request.Flags = p.Pflags
+ case *sshFxpSetstatPacket:
+ request.Flags = p.Flags
+ request.Attrs = p.Attrs.([]byte)
+ case *sshFxpRenamePacket:
+ request.Target = cleanPath(p.Newpath)
+ case *sshFxpSymlinkPacket:
+ // NOTE: given a POSIX compliant signature: symlink(target, linkpath string)
+ // this makes Request.Target the linkpath, and Request.Filepath the target.
+ request.Target = cleanPath(p.Linkpath)
+ case *sshFxpExtendedPacketHardlink:
+ request.Target = cleanPath(p.Newpath)
+ }
+ return request
+}
+
+// Context returns the request's context. To change the context,
+// use WithContext.
+//
+// The returned context is always non-nil; it defaults to the
+// background context.
+//
+// For incoming server requests, the context is canceled when the
+// request is complete or the client's connection closes.
+func (r *Request) Context() context.Context {
+ if r.ctx != nil {
+ return r.ctx
+ }
+ return context.Background()
+}
+
+// WithContext returns a copy of r with its context changed to ctx.
+// The provided ctx must be non-nil.
+func (r *Request) WithContext(ctx context.Context) *Request {
+ if ctx == nil {
+ panic("nil context")
+ }
+ r2 := r.copy()
+ r2.ctx = ctx
+ r2.cancelCtx = nil
+ return r2
+}
+
+// Close reader/writer if possible
+func (r *Request) close() error {
+ defer func() {
+ if r.cancelCtx != nil {
+ r.cancelCtx()
+ }
+ }()
+
+ rd, wr, rw := r.getAllReaderWriters()
+
+ var err error
+
+ // Close errors on a Writer are far more likely to be the important one.
+ // As they can be information that there was a loss of data.
+ if c, ok := wr.(io.Closer); ok {
+ if err2 := c.Close(); err == nil {
+ // update error if it is still nil
+ err = err2
+ }
+ }
+
+ if c, ok := rw.(io.Closer); ok {
+ if err2 := c.Close(); err == nil {
+ // update error if it is still nil
+ err = err2
+
+ r.setWriterAtReaderAt(nil)
+ }
+ }
+
+ if c, ok := rd.(io.Closer); ok {
+ if err2 := c.Close(); err == nil {
+ // update error if it is still nil
+ err = err2
+ }
+ }
+
+ return err
+}
+
+// Notify transfer error if any
+func (r *Request) transferError(err error) {
+ if err == nil {
+ return
+ }
+
+ rd, wr, rw := r.getAllReaderWriters()
+
+ if t, ok := wr.(TransferError); ok {
+ t.TransferError(err)
+ }
+
+ if t, ok := rw.(TransferError); ok {
+ t.TransferError(err)
+ }
+
+ if t, ok := rd.(TransferError); ok {
+ t.TransferError(err)
+ }
+}
+
+// called from worker to handle packet/request
+func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+ switch r.Method {
+ case "Get":
+ return fileget(handlers.FileGet, r, pkt, alloc, orderID)
+ case "Put":
+ return fileput(handlers.FilePut, r, pkt, alloc, orderID)
+ case "Open":
+ return fileputget(handlers.FilePut, r, pkt, alloc, orderID)
+ case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
+ return filecmd(handlers.FileCmd, r, pkt)
+ case "List":
+ return filelist(handlers.FileList, r, pkt)
+ case "Stat", "Lstat", "Readlink":
+ return filestat(handlers.FileList, r, pkt)
+ default:
+ return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method))
+ }
+}
+
+// Additional initialization for Open packets
+func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
+ flags := r.Pflags()
+
+ id := pkt.id()
+
+ switch {
+ case flags.Write, flags.Append, flags.Creat, flags.Trunc:
+ if flags.Read {
+ if openFileWriter, ok := h.FilePut.(OpenFileWriter); ok {
+ r.Method = "Open"
+ rw, err := openFileWriter.OpenFile(r)
+ if err != nil {
+ return statusFromError(id, err)
+ }
+
+ r.setWriterAtReaderAt(rw)
+
+ return &sshFxpHandlePacket{
+ ID: id,
+ Handle: r.handle,
+ }
+ }
+ }
+
+ r.Method = "Put"
+ wr, err := h.FilePut.Filewrite(r)
+ if err != nil {
+ return statusFromError(id, err)
+ }
+
+ r.setWriterAt(wr)
+
+ case flags.Read:
+ r.Method = "Get"
+ rd, err := h.FileGet.Fileread(r)
+ if err != nil {
+ return statusFromError(id, err)
+ }
+
+ r.setReaderAt(rd)
+
+ default:
+ return statusFromError(id, errors.New("bad file flags"))
+ }
+
+ return &sshFxpHandlePacket{
+ ID: id,
+ Handle: r.handle,
+ }
+}
+
+func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
+ r.Method = "List"
+ la, err := h.FileList.Filelist(r)
+ if err != nil {
+ return statusFromError(pkt.id(), wrapPathError(r.Filepath, err))
+ }
+
+ r.setListerAt(la)
+
+ return &sshFxpHandlePacket{
+ ID: pkt.id(),
+ Handle: r.handle,
+ }
+}
+
+// wrap FileReader handler
+func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+ rd := r.getReaderAt()
+ if rd == nil {
+ return statusFromError(pkt.id(), errors.New("unexpected read packet"))
+ }
+
+ data, offset, _ := packetData(pkt, alloc, orderID)
+
+ n, err := rd.ReadAt(data, offset)
+ // only return EOF error if no data left to read
+ if err != nil && (err != io.EOF || n == 0) {
+ return statusFromError(pkt.id(), err)
+ }
+
+ return &sshFxpDataPacket{
+ ID: pkt.id(),
+ Length: uint32(n),
+ Data: data[:n],
+ }
+}
+
+// wrap FileWriter handler
+func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+ wr := r.getWriterAt()
+ if wr == nil {
+ return statusFromError(pkt.id(), errors.New("unexpected write packet"))
+ }
+
+ data, offset, _ := packetData(pkt, alloc, orderID)
+
+ _, err := wr.WriteAt(data, offset)
+ return statusFromError(pkt.id(), err)
+}
+
+// wrap OpenFileWriter handler
+func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+ rw := r.getWriterAtReaderAt()
+ if rw == nil {
+ return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
+ }
+
+ switch p := pkt.(type) {
+ case *sshFxpReadPacket:
+ data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
+
+ n, err := rw.ReadAt(data, offset)
+ // only return EOF error if no data left to read
+ if err != nil && (err != io.EOF || n == 0) {
+ return statusFromError(pkt.id(), err)
+ }
+
+ return &sshFxpDataPacket{
+ ID: pkt.id(),
+ Length: uint32(n),
+ Data: data[:n],
+ }
+
+ case *sshFxpWritePacket:
+ data, offset := p.Data, int64(p.Offset)
+
+ _, err := rw.WriteAt(data, offset)
+ return statusFromError(pkt.id(), err)
+
+ default:
+ return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write"))
+ }
+}
+
+// file data for additional read/write packets
+func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
+ switch p := p.(type) {
+ case *sshFxpReadPacket:
+ return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len
+ case *sshFxpWritePacket:
+ return p.Data, int64(p.Offset), p.Length
+ }
+ return
+}
+
+// wrap FileCmder handler
+func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
+ switch p := pkt.(type) {
+ case *sshFxpFsetstatPacket:
+ r.Flags = p.Flags
+ r.Attrs = p.Attrs.([]byte)
+ }
+
+ switch r.Method {
+ case "PosixRename":
+ if posixRenamer, ok := h.(PosixRenameFileCmder); ok {
+ err := posixRenamer.PosixRename(r)
+ return statusFromError(pkt.id(), err)
+ }
+
+ // PosixRenameFileCmder not implemented handle this request as a Rename
+ r.Method = "Rename"
+ err := h.Filecmd(r)
+ return statusFromError(pkt.id(), err)
+
+ case "StatVFS":
+ if statVFSCmdr, ok := h.(StatVFSFileCmder); ok {
+ stat, err := statVFSCmdr.StatVFS(r)
+ if err != nil {
+ return statusFromError(pkt.id(), err)
+ }
+ stat.ID = pkt.id()
+ return stat
+ }
+
+ return statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
+ }
+
+ err := h.Filecmd(r)
+ return statusFromError(pkt.id(), err)
+}
+
+// wrap FileLister handler
+func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket {
+ lister := r.getListerAt()
+ if lister == nil {
+ return statusFromError(pkt.id(), errors.New("unexpected dir packet"))
+ }
+
+ offset := r.lsNext()
+ finfo := make([]os.FileInfo, MaxFilelist)
+ n, err := lister.ListAt(finfo, offset)
+ r.lsInc(int64(n))
+ // ignore EOF as we only return it when there are no results
+ finfo = finfo[:n] // avoid need for nil tests below
+
+ switch r.Method {
+ case "List":
+ if err != nil && (err != io.EOF || n == 0) {
+ return statusFromError(pkt.id(), err)
+ }
+
+ nameAttrs := make([]*sshFxpNameAttr, 0, len(finfo))
+
+ // If the type conversion fails, we get untyped `nil`,
+ // which is handled by not looking up any names.
+ idLookup, _ := h.(NameLookupFileLister)
+
+ for _, fi := range finfo {
+ nameAttrs = append(nameAttrs, &sshFxpNameAttr{
+ Name: fi.Name(),
+ LongName: runLs(idLookup, fi),
+ Attrs: []interface{}{fi},
+ })
+ }
+
+ return &sshFxpNamePacket{
+ ID: pkt.id(),
+ NameAttrs: nameAttrs,
+ }
+
+ default:
+ err = fmt.Errorf("unexpected method: %s", r.Method)
+ return statusFromError(pkt.id(), err)
+ }
+}
+
+func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket {
+ var lister ListerAt
+ var err error
+
+ if r.Method == "Lstat" {
+ if lstatFileLister, ok := h.(LstatFileLister); ok {
+ lister, err = lstatFileLister.Lstat(r)
+ } else {
+ // LstatFileLister not implemented handle this request as a Stat
+ r.Method = "Stat"
+ lister, err = h.Filelist(r)
+ }
+ } else {
+ lister, err = h.Filelist(r)
+ }
+ if err != nil {
+ return statusFromError(pkt.id(), err)
+ }
+ finfo := make([]os.FileInfo, 1)
+ n, err := lister.ListAt(finfo, 0)
+ finfo = finfo[:n] // avoid need for nil tests below
+
+ switch r.Method {
+ case "Stat", "Lstat":
+ if err != nil && err != io.EOF {
+ return statusFromError(pkt.id(), err)
+ }
+ if n == 0 {
+ err = &os.PathError{
+ Op: strings.ToLower(r.Method),
+ Path: r.Filepath,
+ Err: syscall.ENOENT,
+ }
+ return statusFromError(pkt.id(), err)
+ }
+ return &sshFxpStatResponse{
+ ID: pkt.id(),
+ info: finfo[0],
+ }
+ case "Readlink":
+ if err != nil && err != io.EOF {
+ return statusFromError(pkt.id(), err)
+ }
+ if n == 0 {
+ err = &os.PathError{
+ Op: "readlink",
+ Path: r.Filepath,
+ Err: syscall.ENOENT,
+ }
+ return statusFromError(pkt.id(), err)
+ }
+ filename := finfo[0].Name()
+ return &sshFxpNamePacket{
+ ID: pkt.id(),
+ NameAttrs: []*sshFxpNameAttr{
+ {
+ Name: filename,
+ LongName: filename,
+ Attrs: emptyFileStat,
+ },
+ },
+ }
+ default:
+ err = fmt.Errorf("unexpected method: %s", r.Method)
+ return statusFromError(pkt.id(), err)
+ }
+}
+
+// init attributes of request object from packet data
+func requestMethod(p requestPacket) (method string) {
+ switch p.(type) {
+ case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket:
+ // set in open() above
+ case *sshFxpOpendirPacket, *sshFxpReaddirPacket:
+ // set in opendir() above
+ case *sshFxpSetstatPacket, *sshFxpFsetstatPacket:
+ method = "Setstat"
+ case *sshFxpRenamePacket:
+ method = "Rename"
+ case *sshFxpSymlinkPacket:
+ method = "Symlink"
+ case *sshFxpRemovePacket:
+ method = "Remove"
+ case *sshFxpStatPacket, *sshFxpFstatPacket:
+ method = "Stat"
+ case *sshFxpLstatPacket:
+ method = "Lstat"
+ case *sshFxpRmdirPacket:
+ method = "Rmdir"
+ case *sshFxpReadlinkPacket:
+ method = "Readlink"
+ case *sshFxpMkdirPacket:
+ method = "Mkdir"
+ case *sshFxpExtendedPacketHardlink:
+ method = "Link"
+ }
+ return method
+}
diff --git a/sftp/request_windows.go b/sftp/request_windows.go
new file mode 100644
index 0000000..1f6d3df
--- /dev/null
+++ b/sftp/request_windows.go
@@ -0,0 +1,44 @@
+package sftp
+
+import (
+ "path"
+ "path/filepath"
+ "syscall"
+)
+
+func fakeFileInfoSys() interface{} {
+ return syscall.Win32FileAttributeData{}
+}
+
+func testOsSys(sys interface{}) error {
+ return nil
+}
+
+func toLocalPath(p string) string {
+ lp := filepath.FromSlash(p)
+
+ if path.IsAbs(p) {
+ tmp := lp
+ for len(tmp) > 0 && tmp[0] == '\\' {
+ tmp = tmp[1:]
+ }
+
+ if filepath.IsAbs(tmp) {
+ // If the FromSlash without any starting slashes is absolute,
+ // then we have a filepath encoded with a prefix '/'.
+ // e.g. "/C:/Windows" to "C:\\Windows"
+ return tmp
+ }
+
+ tmp += "\\"
+
+ if filepath.IsAbs(tmp) {
+ // If the FromSlash without any starting slashes but with extra end slash is absolute,
+ // then we have a filepath encoded with a prefix '/' and a dropped '/' at the end.
+ // e.g. "/C:" to "C:\\"
+ return tmp
+ }
+ }
+
+ return lp
+}
diff --git a/sftp/server.go b/sftp/server.go
new file mode 100644
index 0000000..6b2b20f
--- /dev/null
+++ b/sftp/server.go
@@ -0,0 +1,643 @@
+package sftp
+
+// sftp server counterpart
+
+import (
+ "context"
+ "log"
+ "encoding"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "strconv"
+ "sync"
+ "syscall"
+ "time"
+
+ "git.deuxfleurs.fr/Deuxfleurs/bagage/s3"
+)
+
+const (
+ // SftpServerWorkerCount defines the number of workers for the SFTP server
+ SftpServerWorkerCount = 8
+)
+
+// Server is an SSH File Transfer Protocol (sftp) server.
+// This is intended to provide the sftp subsystem to an ssh server daemon.
+// This implementation currently supports most of sftp server protocol version 3,
+// as specified at http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
+type Server struct {
+ *serverConn
+ debugStream io.Writer
+ readOnly bool
+ pktMgr *packetManager
+ openFiles map[string]*s3.S3File
+ openFilesLock sync.RWMutex
+ handleCount int
+ fs *s3.S3FS
+ ctx context.Context
+}
+
+func (svr *Server) nextHandle(f *s3.S3File) string {
+ svr.openFilesLock.Lock()
+ defer svr.openFilesLock.Unlock()
+ svr.handleCount++
+ handle := strconv.Itoa(svr.handleCount)
+ svr.openFiles[handle] = f
+ return handle
+}
+
+func (svr *Server) closeHandle(handle string) error {
+ svr.openFilesLock.Lock()
+ defer svr.openFilesLock.Unlock()
+ if f, ok := svr.openFiles[handle]; ok {
+ delete(svr.openFiles, handle)
+ return f.Close()
+ }
+
+ return EBADF
+}
+
+func (svr *Server) getHandle(handle string) (*s3.S3File, bool) {
+ svr.openFilesLock.RLock()
+ defer svr.openFilesLock.RUnlock()
+ f, ok := svr.openFiles[handle]
+ return f, ok
+}
+
+type serverRespondablePacket interface {
+ encoding.BinaryUnmarshaler
+ id() uint32
+ respond(svr *Server) responsePacket
+}
+
+// NewServer creates a new Server instance around the provided streams, serving
+// content from the root of the filesystem. Optionally, ServerOption
+// functions may be specified to further configure the Server.
+//
+// A subsequent call to Serve() is required to begin serving files over SFTP.
+func NewServer(ctx context.Context, rwc io.ReadWriteCloser, fs *s3.S3FS, options ...ServerOption) (*Server, error) {
+ svrConn := &serverConn{
+ conn: conn{
+ Reader: rwc,
+ WriteCloser: rwc,
+ },
+ }
+ s := &Server{
+ serverConn: svrConn,
+ debugStream: ioutil.Discard,
+ pktMgr: newPktMgr(svrConn),
+ openFiles: make(map[string]*s3.S3File),
+ fs: fs,
+ ctx: ctx,
+ }
+
+ for _, o := range options {
+ if err := o(s); err != nil {
+ return nil, err
+ }
+ }
+
+ return s, nil
+}
+
+// A ServerOption is a function which applies configuration to a Server.
+type ServerOption func(*Server) error
+
+// WithDebug enables Server debugging output to the supplied io.Writer.
+func WithDebug(w io.Writer) ServerOption {
+ return func(s *Server) error {
+ s.debugStream = w
+ return nil
+ }
+}
+
+// ReadOnly configures a Server to serve files in read-only mode.
+func ReadOnly() ServerOption {
+ return func(s *Server) error {
+ s.readOnly = true
+ return nil
+ }
+}
+
+// WithAllocator enable the allocator.
+// After processing a packet we keep in memory the allocated slices
+// and we reuse them for new packets.
+// The allocator is experimental
+func WithAllocator() ServerOption {
+ return func(s *Server) error {
+ alloc := newAllocator()
+ s.pktMgr.alloc = alloc
+ s.conn.alloc = alloc
+ return nil
+ }
+}
+
+type rxPacket struct {
+ pktType fxp
+ pktBytes []byte
+}
+
+// Up to N parallel servers
+func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error {
+ for pkt := range pktChan {
+ // readonly checks
+ readonly := true
+ switch pkt := pkt.requestPacket.(type) {
+ case notReadOnly:
+ readonly = false
+ case *sshFxpOpenPacket:
+ readonly = pkt.readonly()
+ case *sshFxpExtendedPacket:
+ readonly = pkt.readonly()
+ }
+
+ // If server is operating read-only and a write operation is requested,
+ // return permission denied
+ if !readonly && svr.readOnly {
+ svr.pktMgr.readyPacket(
+ svr.pktMgr.newOrderedResponse(statusFromError(pkt.id(), syscall.EPERM), pkt.orderID()),
+ )
+ continue
+ }
+
+ if err := handlePacket(svr, pkt); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func handlePacket(s *Server, p orderedRequest) error {
+ var rpkt responsePacket
+ orderID := p.orderID()
+ switch p := p.requestPacket.(type) {
+ case *sshFxInitPacket:
+ log.Println("pkt: init")
+ rpkt = &sshFxVersionPacket{
+ Version: sftpProtocolVersion,
+ Extensions: sftpExtensions,
+ }
+ case *sshFxpStatPacket:
+ log.Println("pkt: stat: ", p.Path)
+ // stat the requested file
+ info, err := os.Stat(toLocalPath(p.Path))
+ rpkt = &sshFxpStatResponse{
+ ID: p.ID,
+ info: info,
+ }
+ if err != nil {
+ rpkt = statusFromError(p.ID, err)
+ }
+ case *sshFxpLstatPacket:
+ log.Println("pkt: lstat: ", p.Path)
+ // stat the requested file
+ info, err := os.Lstat(toLocalPath(p.Path))
+ rpkt = &sshFxpStatResponse{
+ ID: p.ID,
+ info: info,
+ }
+ if err != nil {
+ rpkt = statusFromError(p.ID, err)
+ }
+ case *sshFxpFstatPacket:
+ log.Println("pkt: fstat: ", p.Handle)
+ f, ok := s.getHandle(p.Handle)
+ var err error = EBADF
+ var info os.FileInfo
+ if ok {
+ info, err = f.Stat()
+ rpkt = &sshFxpStatResponse{
+ ID: p.ID,
+ info: info,
+ }
+ }
+ if err != nil {
+ rpkt = statusFromError(p.ID, err)
+ }
+ case *sshFxpMkdirPacket:
+ log.Println("pkt: mkdir: ", p.Path)
+ err := os.Mkdir(toLocalPath(p.Path), 0755)
+ rpkt = statusFromError(p.ID, err)
+ case *sshFxpRmdirPacket:
+ log.Println("pkt: rmdir: ", p.Path)
+ err := os.Remove(toLocalPath(p.Path))
+ rpkt = statusFromError(p.ID, err)
+ case *sshFxpRemovePacket:
+ log.Println("pkt: rm: ", p.Filename)
+ err := os.Remove(toLocalPath(p.Filename))
+ rpkt = statusFromError(p.ID, err)
+ case *sshFxpRenamePacket:
+ log.Println("pkt: rename: ", p.Oldpath, ", ", p.Newpath)
+ err := os.Rename(toLocalPath(p.Oldpath), toLocalPath(p.Newpath))
+ rpkt = statusFromError(p.ID, err)
+ case *sshFxpSymlinkPacket:
+ log.Println("pkt: ln -s: ", p.Targetpath, ", ", p.Linkpath)
+ err := os.Symlink(toLocalPath(p.Targetpath), toLocalPath(p.Linkpath))
+ rpkt = statusFromError(p.ID, err)
+ case *sshFxpClosePacket:
+ log.Println("pkt: close handle: ", p.Handle)
+ rpkt = statusFromError(p.ID, s.closeHandle(p.Handle))
+ case *sshFxpReadlinkPacket:
+ log.Println("pkt: read: ", p.Path)
+ f, err := os.Readlink(toLocalPath(p.Path))
+ rpkt = &sshFxpNamePacket{
+ ID: p.ID,
+ NameAttrs: []*sshFxpNameAttr{
+ {
+ Name: f,
+ LongName: f,
+ Attrs: emptyFileStat,
+ },
+ },
+ }
+ if err != nil {
+ rpkt = statusFromError(p.ID, err)
+ }
+ case *sshFxpRealpathPacket:
+ log.Println("pkt: absolute path: ", p.Path)
+ f := s3.NewS3Path(p.Path).Path
+ rpkt = &sshFxpNamePacket{
+ ID: p.ID,
+ NameAttrs: []*sshFxpNameAttr{
+ {
+ Name: f,
+ LongName: f,
+ Attrs: emptyFileStat,
+ },
+ },
+ }
+ case *sshFxpOpendirPacket:
+ log.Println("pkt: open dir: ", p.Path)
+ p.Path = s3.NewS3Path(p.Path).Path
+
+ if stat, err := s.fs.Stat(s.ctx, p.Path); err != nil {
+ rpkt = statusFromError(p.ID, err)
+ } else if !stat.IsDir() {
+ rpkt = statusFromError(p.ID, &os.PathError{
+ Path: p.Path, Err: syscall.ENOTDIR})
+ } else {
+ rpkt = (&sshFxpOpenPacket{
+ ID: p.ID,
+ Path: p.Path,
+ Pflags: sshFxfRead,
+ }).respond(s)
+ }
+ case *sshFxpReadPacket:
+ log.Println("pkt: read handle: ", p.Handle)
+ var err error = EBADF
+ f, ok := s.getHandle(p.Handle)
+ if ok {
+ err = nil
+ data := p.getDataSlice(s.pktMgr.alloc, orderID)
+ n, _err := f.ReadAt(data, int64(p.Offset))
+ if _err != nil && (_err != io.EOF || n == 0) {
+ err = _err
+ }
+ rpkt = &sshFxpDataPacket{
+ ID: p.ID,
+ Length: uint32(n),
+ Data: data[:n],
+ // do not use data[:n:n] here to clamp the capacity, we allocated extra capacity above to avoid reallocations
+ }
+ }
+ if err != nil {
+ rpkt = statusFromError(p.ID, err)
+ }
+
+ case *sshFxpWritePacket:
+ log.Println("pkt: write handle: ", p.Handle, ", Offset: ", p.Offset)
+ f, ok := s.getHandle(p.Handle)
+ var err error = EBADF
+ if ok {
+ _, err = f.WriteAt(p.Data, int64(p.Offset))
+ }
+ rpkt = statusFromError(p.ID, err)
+ case *sshFxpExtendedPacket:
+ log.Println("pkt: extended packet")
+ if p.SpecificPacket == nil {
+ rpkt = statusFromError(p.ID, ErrSSHFxOpUnsupported)
+ } else {
+ rpkt = p.respond(s)
+ }
+ case serverRespondablePacket:
+ log.Println("pkt: respondable")
+ rpkt = p.respond(s)
+ default:
+ return fmt.Errorf("unexpected packet type %T", p)
+ }
+
+ s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, orderID))
+ return nil
+}
+
+// Serve serves SFTP connections until the streams stop or the SFTP subsystem
+// is stopped.
+func (svr *Server) Serve() error {
+ defer func() {
+ if svr.pktMgr.alloc != nil {
+ svr.pktMgr.alloc.Free()
+ }
+ }()
+ var wg sync.WaitGroup
+ runWorker := func(ch chan orderedRequest) {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := svr.sftpServerWorker(ch); err != nil {
+ svr.conn.Close() // shuts down recvPacket
+ }
+ }()
+ }
+ pktChan := svr.pktMgr.workerChan(runWorker)
+
+ var err error
+ var pkt requestPacket
+ var pktType uint8
+ var pktBytes []byte
+ for {
+ pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID())
+ if err != nil {
+ // we don't care about releasing allocated pages here, the server will quit and the allocator freed
+ break
+ }
+
+ pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
+ if err != nil {
+ switch {
+ case errors.Is(err, errUnknownExtendedPacket):
+ //if err := svr.serverConn.sendError(pkt, ErrSshFxOpUnsupported); err != nil {
+ // debug("failed to send err packet: %v", err)
+ // svr.conn.Close() // shuts down recvPacket
+ // break
+ //}
+ default:
+ debug("makePacket err: %v", err)
+ svr.conn.Close() // shuts down recvPacket
+ break
+ }
+ }
+
+ pktChan <- svr.pktMgr.newOrderedRequest(pkt)
+ }
+
+ close(pktChan) // shuts down sftpServerWorkers
+ wg.Wait() // wait for all workers to exit
+
+ // close any still-open files
+ for handle, file := range svr.openFiles {
+ fmt.Fprintf(svr.debugStream, "sftp server file with handle %q left open: %v\n", handle, file.Path.Path)
+ file.Close()
+ }
+ return err // error from recvPacket
+}
+
+type ider interface {
+ id() uint32
+}
+
+// The init packet has no ID, so we just return a zero-value ID
+func (p *sshFxInitPacket) id() uint32 { return 0 }
+
+type sshFxpStatResponse struct {
+ ID uint32
+ info os.FileInfo
+}
+
+func (p *sshFxpStatResponse) marshalPacket() ([]byte, []byte, error) {
+ l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(id)
+
+ b := make([]byte, 4, l)
+ b = append(b, sshFxpAttrs)
+ b = marshalUint32(b, p.ID)
+
+ var payload []byte
+ payload = marshalFileInfo(payload, p.info)
+
+ return b, payload, nil
+}
+
+func (p *sshFxpStatResponse) MarshalBinary() ([]byte, error) {
+ header, payload, err := p.marshalPacket()
+ return append(header, payload...), err
+}
+
+var emptyFileStat = []interface{}{uint32(0)}
+
+func (p *sshFxpOpenPacket) readonly() bool {
+ return !p.hasPflags(sshFxfWrite)
+}
+
+func (p *sshFxpOpenPacket) hasPflags(flags ...uint32) bool {
+ for _, f := range flags {
+ if p.Pflags&f == 0 {
+ return false
+ }
+ }
+ return true
+}
+
+func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
+ log.Println("pkt: open: ", p.Path)
+ var osFlags int
+ if p.hasPflags(sshFxfRead, sshFxfWrite) {
+ osFlags |= os.O_RDWR
+ } else if p.hasPflags(sshFxfWrite) {
+ osFlags |= os.O_WRONLY
+ } else if p.hasPflags(sshFxfRead) {
+ osFlags |= os.O_RDONLY
+ } else {
+ // how are they opening?
+ return statusFromError(p.ID, syscall.EINVAL)
+ }
+
+ // Don't use O_APPEND flag as it conflicts with WriteAt.
+ // The sshFxfAppend flag is a no-op here as the client sends the offsets.
+ // @FIXME these flags are currently ignored
+ if p.hasPflags(sshFxfCreat) {
+ osFlags |= os.O_CREATE
+ }
+ if p.hasPflags(sshFxfTrunc) {
+ osFlags |= os.O_TRUNC
+ }
+ if p.hasPflags(sshFxfExcl) {
+ osFlags |= os.O_EXCL
+ }
+
+ f, err := svr.fs.OpenFile2(svr.ctx, p.Path, osFlags, 0644)
+ if err != nil {
+ return statusFromError(p.ID, err)
+ }
+
+ handle := svr.nextHandle(f)
+ return &sshFxpHandlePacket{ID: p.ID, Handle: handle}
+}
+
+func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket {
+ log.Println("pkt: readdir: ", p.Handle)
+ f, ok := svr.getHandle(p.Handle)
+ if !ok {
+ return statusFromError(p.ID, EBADF)
+ }
+
+ dirents, err := f.Readdir(128)
+ if err != nil {
+ return statusFromError(p.ID, err)
+ }
+
+ idLookup := osIDLookup{}
+
+ ret := &sshFxpNamePacket{ID: p.ID}
+ for _, dirent := range dirents {
+ ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{
+ Name: dirent.Name(),
+ LongName: runLs(idLookup, dirent),
+ Attrs: []interface{}{dirent},
+ })
+ }
+ return ret
+}
+
+func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
+ log.Println("pkt: setstat: ", p.Path)
+ // additional unmarshalling is required for each possibility here
+ b := p.Attrs.([]byte)
+ var err error
+
+ p.Path = toLocalPath(p.Path)
+
+ debug("setstat name \"%s\"", p.Path)
+ if (p.Flags & sshFileXferAttrSize) != 0 {
+ var size uint64
+ if size, b, err = unmarshalUint64Safe(b); err == nil {
+ err = os.Truncate(p.Path, int64(size))
+ }
+ }
+ if (p.Flags & sshFileXferAttrPermissions) != 0 {
+ var mode uint32
+ if mode, b, err = unmarshalUint32Safe(b); err == nil {
+ err = os.Chmod(p.Path, os.FileMode(mode))
+ }
+ }
+ if (p.Flags & sshFileXferAttrACmodTime) != 0 {
+ var atime uint32
+ var mtime uint32
+ if atime, b, err = unmarshalUint32Safe(b); err != nil {
+ } else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
+ } else {
+ atimeT := time.Unix(int64(atime), 0)
+ mtimeT := time.Unix(int64(mtime), 0)
+ err = os.Chtimes(p.Path, atimeT, mtimeT)
+ }
+ }
+ if (p.Flags & sshFileXferAttrUIDGID) != 0 {
+ var uid uint32
+ var gid uint32
+ if uid, b, err = unmarshalUint32Safe(b); err != nil {
+ } else if gid, _, err = unmarshalUint32Safe(b); err != nil {
+ } else {
+ err = os.Chown(p.Path, int(uid), int(gid))
+ }
+ }
+
+ return statusFromError(p.ID, err)
+}
+
+func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
+ log.Println("pkt: fsetstat: ", p.Handle)
+ f, ok := svr.getHandle(p.Handle)
+ if !ok {
+ return statusFromError(p.ID, EBADF)
+ }
+
+ // additional unmarshalling is required for each possibility here
+ //b := p.Attrs.([]byte)
+ var err error
+
+ debug("fsetstat name \"%s\"", f.Path.Path)
+ if (p.Flags & sshFileXferAttrSize) != 0 {
+ /*var size uint64
+ if size, b, err = unmarshalUint64Safe(b); err == nil {
+ err = f.Truncate(int64(size))
+ }*/
+ log.Println("WARN: changing size of the file is not supported")
+ }
+ if (p.Flags & sshFileXferAttrPermissions) != 0 {
+ /*var mode uint32
+ if mode, b, err = unmarshalUint32Safe(b); err == nil {
+ err = f.Chmod(os.FileMode(mode))
+ }*/
+ log.Println("WARN: chmod not supported")
+ }
+ if (p.Flags & sshFileXferAttrACmodTime) != 0 {
+ /*var atime uint32
+ var mtime uint32
+ if atime, b, err = unmarshalUint32Safe(b); err != nil {
+ } else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
+ } else {
+ atimeT := time.Unix(int64(atime), 0)
+ mtimeT := time.Unix(int64(mtime), 0)
+ err = os.Chtimes(f.Name(), atimeT, mtimeT)
+ }*/
+ log.Println("WARN: chtimes not supported")
+ }
+ if (p.Flags & sshFileXferAttrUIDGID) != 0 {
+ /*var uid uint32
+ var gid uint32
+ if uid, b, err = unmarshalUint32Safe(b); err != nil {
+ } else if gid, _, err = unmarshalUint32Safe(b); err != nil {
+ } else {
+ err = f.Chown(int(uid), int(gid))
+ }*/
+ log.Println("WARN: chown not supported")
+ }
+
+ return statusFromError(p.ID, err)
+}
+
+func statusFromError(id uint32, err error) *sshFxpStatusPacket {
+ ret := &sshFxpStatusPacket{
+ ID: id,
+ StatusError: StatusError{
+ // sshFXOk = 0
+ // sshFXEOF = 1
+ // sshFXNoSuchFile = 2 ENOENT
+ // sshFXPermissionDenied = 3
+ // sshFXFailure = 4
+ // sshFXBadMessage = 5
+ // sshFXNoConnection = 6
+ // sshFXConnectionLost = 7
+ // sshFXOPUnsupported = 8
+ Code: sshFxOk,
+ },
+ }
+ if err == nil {
+ return ret
+ }
+
+ debug("statusFromError: error is %T %#v", err, err)
+ ret.StatusError.Code = sshFxFailure
+ ret.StatusError.msg = err.Error()
+
+ if os.IsNotExist(err) {
+ ret.StatusError.Code = sshFxNoSuchFile
+ return ret
+ }
+ if code, ok := translateSyscallError(err); ok {
+ ret.StatusError.Code = code
+ return ret
+ }
+
+ switch e := err.(type) {
+ case fxerr:
+ ret.StatusError.Code = uint32(e)
+ default:
+ if e == io.EOF {
+ ret.StatusError.Code = sshFxEOF
+ }
+ }
+
+ return ret
+}
diff --git a/sftp/server_statvfs_darwin.go b/sftp/server_statvfs_darwin.go
new file mode 100644
index 0000000..8c01dac
--- /dev/null
+++ b/sftp/server_statvfs_darwin.go
@@ -0,0 +1,21 @@
+package sftp
+
+import (
+ "syscall"
+)
+
+func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) {
+ return &StatVFS{
+ Bsize: uint64(stat.Bsize),
+ Frsize: uint64(stat.Bsize), // fragment size is a linux thing; use block size here
+ Blocks: stat.Blocks,
+ Bfree: stat.Bfree,
+ Bavail: stat.Bavail,
+ Files: stat.Files,
+ Ffree: stat.Ffree,
+ Favail: stat.Ffree, // not sure how to calculate Favail
+ Fsid: uint64(uint64(stat.Fsid.Val[1])<<32 | uint64(stat.Fsid.Val[0])), // endianness?
+ Flag: uint64(stat.Flags), // assuming POSIX?
+ Namemax: 1024, // man 2 statfs shows: #define MAXPATHLEN 1024
+ }, nil
+}
diff --git a/sftp/server_statvfs_impl.go b/sftp/server_statvfs_impl.go
new file mode 100644
index 0000000..94b6d83
--- /dev/null
+++ b/sftp/server_statvfs_impl.go
@@ -0,0 +1,29 @@
+// +build darwin linux
+
+// fill in statvfs structure with OS specific values
+// Statfs_t is different per-kernel, and only exists on some unixes (not Solaris for instance)
+
+package sftp
+
+import (
+ "syscall"
+)
+
+func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
+ retPkt, err := getStatVFSForPath(p.Path)
+ if err != nil {
+ return statusFromError(p.ID, err)
+ }
+ retPkt.ID = p.ID
+
+ return retPkt
+}
+
+func getStatVFSForPath(name string) (*StatVFS, error) {
+ var stat syscall.Statfs_t
+ if err := syscall.Statfs(name, &stat); err != nil {
+ return nil, err
+ }
+
+ return statvfsFromStatfst(&stat)
+}
diff --git a/sftp/server_statvfs_linux.go b/sftp/server_statvfs_linux.go
new file mode 100644
index 0000000..1d180d4
--- /dev/null
+++ b/sftp/server_statvfs_linux.go
@@ -0,0 +1,22 @@
+// +build linux
+
+package sftp
+
+import (
+ "syscall"
+)
+
+func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) {
+ return &StatVFS{
+ Bsize: uint64(stat.Bsize),
+ Frsize: uint64(stat.Frsize),
+ Blocks: stat.Blocks,
+ Bfree: stat.Bfree,
+ Bavail: stat.Bavail,
+ Files: stat.Files,
+ Ffree: stat.Ffree,
+ Favail: stat.Ffree, // not sure how to calculate Favail
+ Flag: uint64(stat.Flags), // assuming POSIX?
+ Namemax: uint64(stat.Namelen),
+ }, nil
+}
diff --git a/sftp/server_statvfs_plan9.go b/sftp/server_statvfs_plan9.go
new file mode 100644
index 0000000..e71a27d
--- /dev/null
+++ b/sftp/server_statvfs_plan9.go
@@ -0,0 +1,13 @@
+package sftp
+
+import (
+ "syscall"
+)
+
+func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
+ return statusFromError(p.ID, syscall.EPLAN9)
+}
+
+func getStatVFSForPath(name string) (*StatVFS, error) {
+ return nil, syscall.EPLAN9
+}
diff --git a/sftp/server_statvfs_stubs.go b/sftp/server_statvfs_stubs.go
new file mode 100644
index 0000000..fbf4906
--- /dev/null
+++ b/sftp/server_statvfs_stubs.go
@@ -0,0 +1,15 @@
+// +build !darwin,!linux,!plan9
+
+package sftp
+
+import (
+ "syscall"
+)
+
+func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket {
+ return statusFromError(p.ID, syscall.ENOTSUP)
+}
+
+func getStatVFSForPath(name string) (*StatVFS, error) {
+ return nil, syscall.ENOTSUP
+}
diff --git a/sftp/sftp.go b/sftp/sftp.go
new file mode 100644
index 0000000..9a63c39
--- /dev/null
+++ b/sftp/sftp.go
@@ -0,0 +1,258 @@
+// Package sftp implements the SSH File Transfer Protocol as described in
+// https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
+package sftp
+
+import (
+ "fmt"
+)
+
+const (
+ sshFxpInit = 1
+ sshFxpVersion = 2
+ sshFxpOpen = 3
+ sshFxpClose = 4
+ sshFxpRead = 5
+ sshFxpWrite = 6
+ sshFxpLstat = 7
+ sshFxpFstat = 8
+ sshFxpSetstat = 9
+ sshFxpFsetstat = 10
+ sshFxpOpendir = 11
+ sshFxpReaddir = 12
+ sshFxpRemove = 13
+ sshFxpMkdir = 14
+ sshFxpRmdir = 15
+ sshFxpRealpath = 16
+ sshFxpStat = 17
+ sshFxpRename = 18
+ sshFxpReadlink = 19
+ sshFxpSymlink = 20
+ sshFxpStatus = 101
+ sshFxpHandle = 102
+ sshFxpData = 103
+ sshFxpName = 104
+ sshFxpAttrs = 105
+ sshFxpExtended = 200
+ sshFxpExtendedReply = 201
+)
+
+const (
+ sshFxOk = 0
+ sshFxEOF = 1
+ sshFxNoSuchFile = 2
+ sshFxPermissionDenied = 3
+ sshFxFailure = 4
+ sshFxBadMessage = 5
+ sshFxNoConnection = 6
+ sshFxConnectionLost = 7
+ sshFxOPUnsupported = 8
+
+ // see draft-ietf-secsh-filexfer-13
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1
+ sshFxInvalidHandle = 9
+ sshFxNoSuchPath = 10
+ sshFxFileAlreadyExists = 11
+ sshFxWriteProtect = 12
+ sshFxNoMedia = 13
+ sshFxNoSpaceOnFilesystem = 14
+ sshFxQuotaExceeded = 15
+ sshFxUnknownPrincipal = 16
+ sshFxLockConflict = 17
+ sshFxDirNotEmpty = 18
+ sshFxNotADirectory = 19
+ sshFxInvalidFilename = 20
+ sshFxLinkLoop = 21
+ sshFxCannotDelete = 22
+ sshFxInvalidParameter = 23
+ sshFxFileIsADirectory = 24
+ sshFxByteRangeLockConflict = 25
+ sshFxByteRangeLockRefused = 26
+ sshFxDeletePending = 27
+ sshFxFileCorrupt = 28
+ sshFxOwnerInvalid = 29
+ sshFxGroupInvalid = 30
+ sshFxNoMatchingByteRangeLock = 31
+)
+
+const (
+ sshFxfRead = 0x00000001
+ sshFxfWrite = 0x00000002
+ sshFxfAppend = 0x00000004
+ sshFxfCreat = 0x00000008
+ sshFxfTrunc = 0x00000010
+ sshFxfExcl = 0x00000020
+)
+
+var (
+ // supportedSFTPExtensions defines the supported extensions
+ supportedSFTPExtensions = []sshExtensionPair{
+ {"hardlink@openssh.com", "1"},
+ {"posix-rename@openssh.com", "1"},
+ {"statvfs@openssh.com", "2"},
+ }
+ sftpExtensions = supportedSFTPExtensions
+)
+
+type fxp uint8
+
+func (f fxp) String() string {
+ switch f {
+ case sshFxpInit:
+ return "SSH_FXP_INIT"
+ case sshFxpVersion:
+ return "SSH_FXP_VERSION"
+ case sshFxpOpen:
+ return "SSH_FXP_OPEN"
+ case sshFxpClose:
+ return "SSH_FXP_CLOSE"
+ case sshFxpRead:
+ return "SSH_FXP_READ"
+ case sshFxpWrite:
+ return "SSH_FXP_WRITE"
+ case sshFxpLstat:
+ return "SSH_FXP_LSTAT"
+ case sshFxpFstat:
+ return "SSH_FXP_FSTAT"
+ case sshFxpSetstat:
+ return "SSH_FXP_SETSTAT"
+ case sshFxpFsetstat:
+ return "SSH_FXP_FSETSTAT"
+ case sshFxpOpendir:
+ return "SSH_FXP_OPENDIR"
+ case sshFxpReaddir:
+ return "SSH_FXP_READDIR"
+ case sshFxpRemove:
+ return "SSH_FXP_REMOVE"
+ case sshFxpMkdir:
+ return "SSH_FXP_MKDIR"
+ case sshFxpRmdir:
+ return "SSH_FXP_RMDIR"
+ case sshFxpRealpath:
+ return "SSH_FXP_REALPATH"
+ case sshFxpStat:
+ return "SSH_FXP_STAT"
+ case sshFxpRename:
+ return "SSH_FXP_RENAME"
+ case sshFxpReadlink:
+ return "SSH_FXP_READLINK"
+ case sshFxpSymlink:
+ return "SSH_FXP_SYMLINK"
+ case sshFxpStatus:
+ return "SSH_FXP_STATUS"
+ case sshFxpHandle:
+ return "SSH_FXP_HANDLE"
+ case sshFxpData:
+ return "SSH_FXP_DATA"
+ case sshFxpName:
+ return "SSH_FXP_NAME"
+ case sshFxpAttrs:
+ return "SSH_FXP_ATTRS"
+ case sshFxpExtended:
+ return "SSH_FXP_EXTENDED"
+ case sshFxpExtendedReply:
+ return "SSH_FXP_EXTENDED_REPLY"
+ default:
+ return "unknown"
+ }
+}
+
+type fx uint8
+
+func (f fx) String() string {
+ switch f {
+ case sshFxOk:
+ return "SSH_FX_OK"
+ case sshFxEOF:
+ return "SSH_FX_EOF"
+ case sshFxNoSuchFile:
+ return "SSH_FX_NO_SUCH_FILE"
+ case sshFxPermissionDenied:
+ return "SSH_FX_PERMISSION_DENIED"
+ case sshFxFailure:
+ return "SSH_FX_FAILURE"
+ case sshFxBadMessage:
+ return "SSH_FX_BAD_MESSAGE"
+ case sshFxNoConnection:
+ return "SSH_FX_NO_CONNECTION"
+ case sshFxConnectionLost:
+ return "SSH_FX_CONNECTION_LOST"
+ case sshFxOPUnsupported:
+ return "SSH_FX_OP_UNSUPPORTED"
+ default:
+ return "unknown"
+ }
+}
+
+type unexpectedPacketErr struct {
+ want, got uint8
+}
+
+func (u *unexpectedPacketErr) Error() string {
+ return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got))
+}
+
+func unimplementedPacketErr(u uint8) error {
+ return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u))
+}
+
+type unexpectedIDErr struct{ want, got uint32 }
+
+func (u *unexpectedIDErr) Error() string {
+ return fmt.Sprintf("sftp: unexpected id: want %d, got %d", u.want, u.got)
+}
+
+func unimplementedSeekWhence(whence int) error {
+ return fmt.Errorf("sftp: unimplemented seek whence %d", whence)
+}
+
+func unexpectedCount(want, got uint32) error {
+ return fmt.Errorf("sftp: unexpected count: want %d, got %d", want, got)
+}
+
+type unexpectedVersionErr struct{ want, got uint32 }
+
+func (u *unexpectedVersionErr) Error() string {
+ return fmt.Sprintf("sftp: unexpected server version: want %v, got %v", u.want, u.got)
+}
+
+// A StatusError is returned when an SFTP operation fails, and provides
+// additional information about the failure.
+type StatusError struct {
+ Code uint32
+ msg, lang string
+}
+
+func (s *StatusError) Error() string {
+ return fmt.Sprintf("sftp: %q (%v)", s.msg, fx(s.Code))
+}
+
+// FxCode returns the error code typed to match against the exported codes
+func (s *StatusError) FxCode() fxerr {
+ return fxerr(s.Code)
+}
+
+func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error) {
+ for _, supportedExtension := range supportedSFTPExtensions {
+ if supportedExtension.Name == extensionName {
+ return supportedExtension, nil
+ }
+ }
+ return sshExtensionPair{}, fmt.Errorf("unsupported extension: %s", extensionName)
+}
+
+// SetSFTPExtensions allows to customize the supported server extensions.
+// See the variable supportedSFTPExtensions for supported extensions.
+// This method accepts a slice of sshExtensionPair names for example 'hardlink@openssh.com'.
+// If an invalid extension is given an error will be returned and nothing will be changed
+func SetSFTPExtensions(extensions ...string) error {
+ tempExtensions := []sshExtensionPair{}
+ for _, extension := range extensions {
+ sftpExtension, err := getSupportedExtensionByName(extension)
+ if err != nil {
+ return err
+ }
+ tempExtensions = append(tempExtensions, sftpExtension)
+ }
+ sftpExtensions = tempExtensions
+ return nil
+}
diff --git a/sftp/stat_plan9.go b/sftp/stat_plan9.go
new file mode 100644
index 0000000..761abdf
--- /dev/null
+++ b/sftp/stat_plan9.go
@@ -0,0 +1,103 @@
+package sftp
+
+import (
+ "os"
+ "syscall"
+)
+
+var EBADF = syscall.NewError("fd out of range or not open")
+
+func wrapPathError(filepath string, err error) error {
+ if errno, ok := err.(syscall.ErrorString); ok {
+ return &os.PathError{Path: filepath, Err: errno}
+ }
+ return err
+}
+
+// translateErrno translates a syscall error number to a SFTP error code.
+func translateErrno(errno syscall.ErrorString) uint32 {
+ switch errno {
+ case "":
+ return sshFxOk
+ case syscall.ENOENT:
+ return sshFxNoSuchFile
+ case syscall.EPERM:
+ return sshFxPermissionDenied
+ }
+
+ return sshFxFailure
+}
+
+func translateSyscallError(err error) (uint32, bool) {
+ switch e := err.(type) {
+ case syscall.ErrorString:
+ return translateErrno(e), true
+ case *os.PathError:
+ debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err)
+ if errno, ok := e.Err.(syscall.ErrorString); ok {
+ return translateErrno(errno), true
+ }
+ }
+ return 0, false
+}
+
+// isRegular returns true if the mode describes a regular file.
+func isRegular(mode uint32) bool {
+ return mode&S_IFMT == syscall.S_IFREG
+}
+
+// toFileMode converts sftp filemode bits to the os.FileMode specification
+func toFileMode(mode uint32) os.FileMode {
+ var fm = os.FileMode(mode & 0777)
+
+ switch mode & S_IFMT {
+ case syscall.S_IFBLK:
+ fm |= os.ModeDevice
+ case syscall.S_IFCHR:
+ fm |= os.ModeDevice | os.ModeCharDevice
+ case syscall.S_IFDIR:
+ fm |= os.ModeDir
+ case syscall.S_IFIFO:
+ fm |= os.ModeNamedPipe
+ case syscall.S_IFLNK:
+ fm |= os.ModeSymlink
+ case syscall.S_IFREG:
+ // nothing to do
+ case syscall.S_IFSOCK:
+ fm |= os.ModeSocket
+ }
+
+ return fm
+}
+
+// fromFileMode converts from the os.FileMode specification to sftp filemode bits
+func fromFileMode(mode os.FileMode) uint32 {
+ ret := uint32(mode & os.ModePerm)
+
+ switch mode & os.ModeType {
+ case os.ModeDevice | os.ModeCharDevice:
+ ret |= syscall.S_IFCHR
+ case os.ModeDevice:
+ ret |= syscall.S_IFBLK
+ case os.ModeDir:
+ ret |= syscall.S_IFDIR
+ case os.ModeNamedPipe:
+ ret |= syscall.S_IFIFO
+ case os.ModeSymlink:
+ ret |= syscall.S_IFLNK
+ case 0:
+ ret |= syscall.S_IFREG
+ case os.ModeSocket:
+ ret |= syscall.S_IFSOCK
+ }
+
+ return ret
+}
+
+// Plan 9 doesn't have setuid, setgid or sticky, but a Plan 9 client should
+// be able to send these bits to a POSIX server.
+const (
+ s_ISUID = 04000
+ s_ISGID = 02000
+ s_ISVTX = 01000
+)
diff --git a/sftp/stat_posix.go b/sftp/stat_posix.go
new file mode 100644
index 0000000..5b870e2
--- /dev/null
+++ b/sftp/stat_posix.go
@@ -0,0 +1,124 @@
+//go:build !plan9
+// +build !plan9
+
+package sftp
+
+import (
+ "os"
+ "syscall"
+)
+
+const EBADF = syscall.EBADF
+
+func wrapPathError(filepath string, err error) error {
+ if errno, ok := err.(syscall.Errno); ok {
+ return &os.PathError{Path: filepath, Err: errno}
+ }
+ return err
+}
+
+// translateErrno translates a syscall error number to a SFTP error code.
+func translateErrno(errno syscall.Errno) uint32 {
+ switch errno {
+ case 0:
+ return sshFxOk
+ case syscall.ENOENT:
+ return sshFxNoSuchFile
+ case syscall.EACCES, syscall.EPERM:
+ return sshFxPermissionDenied
+ }
+
+ return sshFxFailure
+}
+
+func translateSyscallError(err error) (uint32, bool) {
+ switch e := err.(type) {
+ case syscall.Errno:
+ return translateErrno(e), true
+ case *os.PathError:
+ debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err)
+ if errno, ok := e.Err.(syscall.Errno); ok {
+ return translateErrno(errno), true
+ }
+ }
+ return 0, false
+}
+
+// isRegular returns true if the mode describes a regular file.
+func isRegular(mode uint32) bool {
+ return mode&S_IFMT == syscall.S_IFREG
+}
+
+// toFileMode converts sftp filemode bits to the os.FileMode specification
+func toFileMode(mode uint32) os.FileMode {
+ var fm = os.FileMode(mode & 0777)
+
+ switch mode & S_IFMT {
+ case syscall.S_IFBLK:
+ fm |= os.ModeDevice
+ case syscall.S_IFCHR:
+ fm |= os.ModeDevice | os.ModeCharDevice
+ case syscall.S_IFDIR:
+ fm |= os.ModeDir
+ case syscall.S_IFIFO:
+ fm |= os.ModeNamedPipe
+ case syscall.S_IFLNK:
+ fm |= os.ModeSymlink
+ case syscall.S_IFREG:
+ // nothing to do
+ case syscall.S_IFSOCK:
+ fm |= os.ModeSocket
+ }
+
+ if mode&syscall.S_ISUID != 0 {
+ fm |= os.ModeSetuid
+ }
+ if mode&syscall.S_ISGID != 0 {
+ fm |= os.ModeSetgid
+ }
+ if mode&syscall.S_ISVTX != 0 {
+ fm |= os.ModeSticky
+ }
+
+ return fm
+}
+
+// fromFileMode converts from the os.FileMode specification to sftp filemode bits
+func fromFileMode(mode os.FileMode) uint32 {
+ ret := uint32(mode & os.ModePerm)
+
+ switch mode & os.ModeType {
+ case os.ModeDevice | os.ModeCharDevice:
+ ret |= syscall.S_IFCHR
+ case os.ModeDevice:
+ ret |= syscall.S_IFBLK
+ case os.ModeDir:
+ ret |= syscall.S_IFDIR
+ case os.ModeNamedPipe:
+ ret |= syscall.S_IFIFO
+ case os.ModeSymlink:
+ ret |= syscall.S_IFLNK
+ case 0:
+ ret |= syscall.S_IFREG
+ case os.ModeSocket:
+ ret |= syscall.S_IFSOCK
+ }
+
+ if mode&os.ModeSetuid != 0 {
+ ret |= syscall.S_ISUID
+ }
+ if mode&os.ModeSetgid != 0 {
+ ret |= syscall.S_ISGID
+ }
+ if mode&os.ModeSticky != 0 {
+ ret |= syscall.S_ISVTX
+ }
+
+ return ret
+}
+
+const (
+ s_ISUID = syscall.S_ISUID
+ s_ISGID = syscall.S_ISGID
+ s_ISVTX = syscall.S_ISVTX
+)
diff --git a/sftp/syscall_fixed.go b/sftp/syscall_fixed.go
new file mode 100644
index 0000000..d404577
--- /dev/null
+++ b/sftp/syscall_fixed.go
@@ -0,0 +1,9 @@
+// +build plan9 windows js,wasm
+
+// Go defines S_IFMT on windows, plan9 and js/wasm as 0x1f000 instead of
+// 0xf000. None of the the other S_IFxyz values include the "1" (in 0x1f000)
+// which prevents them from matching the bitmask.
+
+package sftp
+
+const S_IFMT = 0xf000
diff --git a/sftp/syscall_good.go b/sftp/syscall_good.go
new file mode 100644
index 0000000..4c2b240
--- /dev/null
+++ b/sftp/syscall_good.go
@@ -0,0 +1,8 @@
+// +build !plan9,!windows
+// +build !js !wasm
+
+package sftp
+
+import "syscall"
+
+const S_IFMT = syscall.S_IFMT
diff --git a/webdav.go b/webdav.go
index 3f3f7d5..8f55b44 100644
--- a/webdav.go
+++ b/webdav.go
@@ -2,6 +2,7 @@ package main
import (
"github.com/minio/minio-go/v7"
+ "git.deuxfleurs.fr/Deuxfleurs/bagage/s3"
"golang.org/x/net/webdav"
"log"
"net/http"
@@ -15,7 +16,7 @@ func (wd WebDav) WithMC(mc *minio.Client) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
(&webdav.Handler{
Prefix: wd.WithConfig.DavPath,
- FileSystem: NewS3FS(mc),
+ FileSystem: s3.NewS3FS(mc),
LockSystem: webdav.NewMemLS(),
Logger: func(r *http.Request, err error) {
log.Printf("INFO: %s %s %s\n", r.RemoteAddr, r.Method, r.URL)