diff --git a/server.go b/server.go index 2e419f59..0c990be8 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "io/fs" "io/ioutil" "os" "path/filepath" @@ -21,6 +22,51 @@ const ( SftpServerWorkerCount = 8 ) +type FileLike interface { + Stat() (os.FileInfo, error) + ReadAt(b []byte, off int64) (int, error) + WriteAt(b []byte, off int64) (int, error) + Readdir(int) ([]os.FileInfo, error) + Name() string + Truncate(int64) error + Chmod(mode fs.FileMode) error + Chown(uid, gid int) error + Close() error +} + +type dummyFile struct { +} + +func (f *dummyFile) Stat() (os.FileInfo, error) { + return nil, os.ErrPermission +} +func (f *dummyFile) ReadAt(b []byte, off int64) (int, error) { + return 0, os.ErrPermission +} +func (f *dummyFile) WriteAt(b []byte, off int64) (int, error) { + return 0, os.ErrPermission +} +func (f *dummyFile) Readdir(int) ([]os.FileInfo, error) { + return nil, os.ErrPermission +} +func (f *dummyFile) Name() string { + return "dummyFile" +} +func (f *dummyFile) Truncate(int64) error { + return os.ErrPermission +} +func (f *dummyFile) Chmod(mode fs.FileMode) error { + return os.ErrPermission +} +func (f *dummyFile) Chown(uid, gid int) error { + return os.ErrPermission +} +func (f *dummyFile) Close() error { + return os.ErrPermission +} + +var _ = dummyFile{} // ignore unused + // Server is an SSH File Transfer Protocol (sftp) server. // This is intended to provide the sftp subsystem to an ssh server daemon. // This implementation currently supports most of sftp server protocol version 3, @@ -30,13 +76,13 @@ type Server struct { debugStream io.Writer readOnly bool pktMgr *packetManager - openFiles map[string]*os.File + openFiles map[string]FileLike openFilesLock sync.RWMutex handleCount int workDir string } -func (svr *Server) nextHandle(f *os.File) string { +func (svr *Server) nextHandle(f FileLike) string { svr.openFilesLock.Lock() defer svr.openFilesLock.Unlock() svr.handleCount++ @@ -56,7 +102,7 @@ func (svr *Server) closeHandle(handle string) error { return EBADF } -func (svr *Server) getHandle(handle string) (*os.File, bool) { +func (svr *Server) getHandle(handle string) (FileLike, bool) { svr.openFilesLock.RLock() defer svr.openFilesLock.RUnlock() f, ok := svr.openFiles[handle] @@ -85,7 +131,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error) serverConn: svrConn, debugStream: ioutil.Discard, pktMgr: newPktMgr(svrConn), - openFiles: make(map[string]*os.File), + openFiles: make(map[string]FileLike), } for _, o := range options { @@ -462,7 +508,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { osFlags |= os.O_EXCL } - f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644) + f, err := openFileLike(svr.toLocalPath(p.Path), osFlags, 0o644) if err != nil { return statusFromError(p.ID, err) } diff --git a/server_posix.go b/server_posix.go new file mode 100644 index 00000000..0fbb056c --- /dev/null +++ b/server_posix.go @@ -0,0 +1,13 @@ +//go:build !windows +// +build !windows + +package sftp + +import ( + "io/fs" + "os" +) + +func openFileLike(path string, flag int, mode fs.FileMode) (FileLike, error) { + return os.OpenFile(path, flag, mode) +} diff --git a/server_windows.go b/server_windows.go index b35be730..68bf4a87 100644 --- a/server_windows.go +++ b/server_windows.go @@ -1,8 +1,15 @@ +//go:build go1.18 + package sftp import ( + "fmt" + "io/fs" + "os" "path" "path/filepath" + "syscall" + "time" ) func (s *Server) toLocalPath(p string) string { @@ -37,3 +44,76 @@ func (s *Server) toLocalPath(p string) string { return lp } + +var kernel32, _ = syscall.LoadLibrary("kernel32.dll") +var getLogicalDrivesHandle, _ = syscall.GetProcAddress(kernel32, "GetLogicalDrives") + +func bitsToDrives(bitmap uint32) []string { + var drive rune = 'A' + var drives []string + + for bitmap != 0 { + if bitmap&1 == 1 { + drives = append(drives, string(drive)) + } + drive++ + bitmap >>= 1 + } + + return drives +} + +func getDrives() ([]string, error) { + if ret, _, callErr := syscall.Syscall(uintptr(getLogicalDrivesHandle), 0, 0, 0, 0); callErr != 0 { + return nil, fmt.Errorf("GetLogicalDrives: %w", callErr) + } else { + drives := bitsToDrives(uint32(ret)) + return drives, nil + } +} + +type dummyDriveStat struct { + name string +} + +func (s *dummyDriveStat) Name() string { + return s.name +} +func (s *dummyDriveStat) Size() int64 { + return 1024 +} +func (s *dummyDriveStat) Mode() os.FileMode { + return os.FileMode(0755) +} +func (s *dummyDriveStat) ModTime() time.Time { + return time.Now() +} +func (s *dummyDriveStat) IsDir() bool { + return true +} +func (s *dummyDriveStat) Sys() any { + return nil +} + +type WinRoot struct { + dummyFile +} + +func (f *WinRoot) Readdir(int) ([]os.FileInfo, error) { + drives, err := getDrives() + if err != nil { + return nil, err + } + infos := []os.FileInfo{} + for _, drive := range drives { + infos = append(infos, &dummyDriveStat{drive}) + } + return infos, nil +} + +func openFileLike(path string, flag int, mode fs.FileMode) (FileLike, error) { + if path == "/" { + return &WinRoot{}, nil + } + return os.OpenFile(path, flag, mode) +}