diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..0994b06 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,37 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + test: + name: go-test + runs-on: windows-latest + env: + CGO_ENABLED: 0 + steps: + - name: disable-auto-crlf + run: | + git config --global core.autocrlf false + git config --global core.eol lf + + - name: clone-repo + uses: actions/checkout@v4 + + - name: setup-go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: go-vet-fmt-test # fmt check see Test_Gofmt + run : | + go vet + go test -v -timeout 120s -tags "-race" ./... + + diff --git a/adapter.go b/adapter.go index 4efed18..8164189 100644 --- a/adapter.go +++ b/adapter.go @@ -6,6 +6,7 @@ package wintun import ( "context" "sync" + "syscall" "github.com/pkg/errors" @@ -30,7 +31,11 @@ func (a *Adapter) sessionLocked(trap uintptr, args ...uintptr) (r1, r2 uintptr, } else if a.session == 0 { return 0, 0, errors.WithStack(ErrAdapterStoped{}) } - return global.calln(trap, append([]uintptr{a.session}, args...)...) + r1, r2, err = syscall.SyscallN(trap, append([]uintptr{a.session}, args...)...) + if err == windows.ERROR_SUCCESS { + err = nil + } + return r1, r2, errors.WithStack(err) } func (a *Adapter) Start(capacity uint32) (err error) { @@ -43,12 +48,12 @@ func (a *Adapter) Start(capacity uint32) (err error) { if a.handle == 0 { return errors.WithStack(ErrAdapterClosed{}) } - fd, _, err := global.calln( - global.procStartSession, + fd, _, err := syscall.SyscallN( + procStartSession.Addr(), a.handle, uintptr(capacity), ) - if err != nil { + if err != windows.ERROR_SUCCESS { return err } a.session = fd @@ -63,8 +68,8 @@ func (a *Adapter) Stop() error { func (a *Adapter) stopLocked() error { if a.session > 0 { - _, _, err := global.calln(global.procEndSession, uintptr(a.session)) - if err != nil { + _, _, err := syscall.SyscallN(procEndSession.Addr(), uintptr(a.session)) + if err != windows.ERROR_SUCCESS { return err } a.session = 0 @@ -82,8 +87,8 @@ func (a *Adapter) Close() error { return err } - _, _, err = global.calln(global.procCloseAdapter, a.handle) - if err != nil { + _, _, err = syscall.SyscallN(procCloseAdapter.Addr(), a.handle) + if err != windows.ERROR_SUCCESS { return err } a.handle = 0 @@ -99,12 +104,12 @@ func (a *Adapter) GetAdapterLuid() (winipcfg.LUID, error) { return 0, errors.WithStack(ErrAdapterClosed{}) } var luid uint64 - _, _, err := global.calln( - global.procGetAdapterLuid, + _, _, err := syscall.SyscallN( + procGetAdapterLUID.Addr(), a.handle, uintptr(unsafe.Pointer(&luid)), ) - if err != nil { + if err != windows.ERROR_SUCCESS { return 0, err } return winipcfg.LUID(luid), nil @@ -124,7 +129,7 @@ func (a *Adapter) Index() (int, error) { } func (a *Adapter) getReadWaitEvent() (windows.Handle, error) { - r0, _, err := a.sessionLocked(global.procGetReadWaitEvent) + r0, _, err := a.sessionLocked(procGetReadWaitEvent.Addr()) if r0 == 0 { return 0, err } @@ -141,7 +146,7 @@ func (a *Adapter) Recv(ctx context.Context) (ip rpack, err error) { var size uint32 for { r0, _, err := a.sessionLocked( - global.procReceivePacket, + procReceivePacket.Addr(), (uintptr)(unsafe.Pointer(&size)), ) @@ -187,7 +192,7 @@ func (a *Adapter) Release(p rpack) error { defer a.mu.RUnlock() _, _, err := a.sessionLocked( - global.procReleaseReceivePacket, + procReleaseReceivePacket.Addr(), uintptr(unsafe.Pointer(&p[0])), ) return err @@ -203,7 +208,7 @@ func (a *Adapter) Alloc(size int) (spack, error) { defer a.mu.RUnlock() r0, _, err := a.sessionLocked( - global.procAllocateSendPacket, + procAllocateSendPacket.Addr(), uintptr(size), ) if r0 == 0 { @@ -223,7 +228,7 @@ func (a *Adapter) Send(ip spack) error { defer a.mu.RUnlock() _, _, err := a.sessionLocked( - global.procSendPacket, + procSendPacket.Addr(), uintptr(unsafe.Pointer(&ip[0])), ) return err diff --git a/adapter_test.go b/adapter_test.go index 2f9f703..aebf24e 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -19,8 +19,7 @@ import ( ) func Test_Invalid_Ring_Capacity(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) t.Run("lesser", func(t *testing.T) { ap, err := wintun.CreateAdapter("testinvalidringlesser") @@ -47,8 +46,7 @@ func Test_Invalid_Ring_Capacity(t *testing.T) { } func Test_Adapter_Create(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) t.Run("create/start", func(t *testing.T) { ap, err := wintun.CreateAdapter("createstart") @@ -82,8 +80,7 @@ func Test_Adapter_Create(t *testing.T) { } func Test_Adapter_Stoped_Recv(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) ap, err := wintun.CreateAdapter("testadapterrwstoped") require.NoError(t, err) @@ -97,8 +94,7 @@ func Test_Adapter_Stoped_Recv(t *testing.T) { } func Test_Adapter_Index(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) name := "testadapterindex" @@ -122,8 +118,8 @@ func Test_Adapter_Index(t *testing.T) { } func Test_Recv(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) + var ( ip = netip.AddrFrom4([4]byte{10, 1, 1, 11}) laddr = &net.UDPAddr{IP: ip.AsSlice(), Port: randPort()} @@ -189,8 +185,7 @@ func Test_Recv(t *testing.T) { } func Test_RecvCtx(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) ap, err := wintun.CreateAdapter("rcecvctx") require.NoError(t, err) @@ -212,8 +207,7 @@ func Test_RecvCtx(t *testing.T) { func Test_Recving_Close(t *testing.T) { // if remove Close and Recv mutex, will fatal Exception - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) for i := 0; i < 0xf; i++ { func() { @@ -241,8 +235,8 @@ func Test_Recving_Close(t *testing.T) { } func Test_Echo_UDP_Adapter(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) + var ( ip = netip.AddrFrom4([4]byte{10, 0, 1, 3}) laddr = &net.UDPAddr{IP: ip.AsSlice(), Port: randPort()} @@ -316,8 +310,7 @@ func Test_Packet_Sniffing(t *testing.T) { t.Skip("todo:maybe not route") // route add 0.0.0.0 mask 0.0.0.0 10.0.1.3 metric 5 if 116 - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) var ( ip = netip.AddrFrom4([4]byte{10, 0, 1, 3}) @@ -379,8 +372,7 @@ func Test_Packet_Sniffing(t *testing.T) { } func Test_Session_Restart(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) ap, err := wintun.CreateAdapter("testsessionrestart") require.NoError(t, err) diff --git a/dll.go b/dll.go deleted file mode 100644 index 77e3ea4..0000000 --- a/dll.go +++ /dev/null @@ -1,65 +0,0 @@ -package wintun - -import ( - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/driver/memmod" -) - -type dll interface { - Release() error - FindProc(string) (uintptr, error) - MustFindProc(string) uintptr -} - -type file windows.DLL - -func loadFileDLL(path string) (dll, error) { - dll, err := windows.LoadDLL(path) - if err != nil { - return nil, err - } - return (*file)(dll), nil -} - -func (d *file) FindProc(name string) (uintptr, error) { - p, err := ((*windows.DLL)(d)).FindProc(name) - if err != nil { - return 0, err - } - return p.Addr(), nil -} -func (d *file) MustFindProc(name string) uintptr { - hdl, err := d.FindProc(name) - if err != nil { - panic(err) - } - return hdl -} -func (d *file) Release() error { - return ((*windows.DLL)(d)).Release() -} - -type mem memmod.Module - -func loadMemDLL(data []byte) (dll, error) { - d, err := memmod.LoadLibrary(data) - if err != nil { - return nil, err - } - return (*mem)(d), nil -} - -func (d *mem) FindProc(name string) (uintptr, error) { - return ((*memmod.Module)(d)).ProcAddressByName(name) -} -func (d *mem) MustFindProc(name string) uintptr { - hdl, err := d.FindProc(name) - if err != nil { - panic(err) - } - return hdl -} -func (d *mem) Release() error { - ((*memmod.Module)(d)).Free() - return nil -} diff --git a/embed_windows_386.go b/embed_windows_386.go index 0746427..b1dbbe1 100644 --- a/embed_windows_386.go +++ b/embed_windows_386.go @@ -4,5 +4,7 @@ import ( _ "embed" ) +// from https://www.wintun.net/builds/wintun-0.14.1.zip +// //go:embed embed/wintun_x86.dll var DLL Mem diff --git a/embed_windows_amd64.go b/embed_windows_amd64.go index 07199f6..6cb981b 100644 --- a/embed_windows_amd64.go +++ b/embed_windows_amd64.go @@ -4,7 +4,7 @@ import ( _ "embed" ) -// https://www.wintun.net/builds/wintun-0.14.1.zip - +// from https://www.wintun.net/builds/wintun-0.14.1.zip +// //go:embed embed/wintun_amd64.dll var DLL Mem diff --git a/embed_windows_arm.go b/embed_windows_arm.go index 6e435da..410eab7 100644 --- a/embed_windows_arm.go +++ b/embed_windows_arm.go @@ -4,5 +4,7 @@ import ( _ "embed" ) +// from https://www.wintun.net/builds/wintun-0.14.1.zip +// //go:embed embed/wintun_arm.dll var DLL Mem diff --git a/embed_windows_arm64.go b/embed_windows_arm64.go index 8f52b47..4045549 100644 --- a/embed_windows_arm64.go +++ b/embed_windows_arm64.go @@ -4,5 +4,7 @@ import ( _ "embed" ) +// from https://www.wintun.net/builds/wintun-0.14.1.zip +// //go:embed embed/wintun_arm64.dll var DLL Mem diff --git a/go.mod b/go.mod index 33f3e0d..60febbf 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( ) require ( + github.com/lysShub/divert-go v0.0.0-20240525230502-6f79596abd61 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.4 golang.org/x/sys v0.16.0 diff --git a/go.sum b/go.sum index 1ce76d7..f4d076d 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/lysShub/divert-go v0.0.0-20240525230502-6f79596abd61 h1:qqarPA8zZe+LnIGHaleqDikaQ3QzvlAfDigXYRrboHU= +github.com/lysShub/divert-go v0.0.0-20240525230502-6f79596abd61/go.mod h1:OXuD4Q/Y84FyNiYy/sf9RVshvAC5/rvcHA6J7JvvtFM= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/logger.go b/logger.go index 2e95909..350390b 100644 --- a/logger.go +++ b/logger.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "runtime" + "syscall" "github.com/pkg/errors" "golang.org/x/sys/windows" @@ -59,7 +60,7 @@ func SetLogger(logger LoggerCallback) error { } } - _, _, err := global.calln(global.procSetLogger, callback) + _, _, err := syscall.SyscallN(procSetLogger.Addr(), callback) if err != windows.ERROR_SUCCESS { return errors.WithStack(err) } diff --git a/readme.md b/readme.md index 318b062..e53fcc0 100644 --- a/readme.md +++ b/readme.md @@ -1,88 +1,88 @@ -# wintun-go - -golang client for [wintun](https://git.zx2c4.com/wintun/about/) - - -##### Example: -```golang -package main - -import ( - "context" - "log" - "net" - "net/netip" - - "github.com/lysShub/wintun-go" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "gvisor.dev/gvisor/pkg/tcpip/header" // go get gvisor.dev/gvisor@go -) - -// curl google.com -func main() { - wintun.MustLoad(wintun.DLL) - defer wintun.Release() - - ips, err := net.DefaultResolver.LookupIP(context.Background(), "ip4", "google.com") - if err != nil { - log.Fatal(err) - } - - ap, err := wintun.CreateAdapter("capture-google") - if err != nil { - log.Fatal(err) - } - defer ap.Close() - luid, err := ap.GetAdapterLuid() - if err != nil { - log.Fatal(err) - } - - var addr = netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 7, 3}), 24) - err = luid.SetIPAddresses([]netip.Prefix{addr}) - if err != nil { - log.Fatal(err) - } - - var routs []*winipcfg.RouteData - for _, e := range ips { - ip := netip.AddrFrom4([4]byte(e)) - dst := netip.PrefixFrom(ip, ip.BitLen()) - routs = append(routs, &winipcfg.RouteData{ - Destination: dst, - NextHop: addr.Addr(), - Metric: 5, - }) - } - err = luid.AddRoutes(routs) - if err != nil { - log.Fatal(err) - } - - for { - ip, err := ap.Receive(context.Background()) - if err != nil { - log.Fatal(err) - } - - if header.IPVersion(ip) == 4 { - iphdr := header.IPv4(ip) - if iphdr.TransportProtocol() == header.TCPProtocolNumber { - tcphdr := header.TCP(iphdr.Payload()) - - log.Printf("%s:%d --> %s:%d %s\n", - iphdr.SourceAddress().String(), tcphdr.SourcePort(), - iphdr.DestinationAddress().String(), tcphdr.DestinationPort(), - tcphdr.Flags(), - ) - } - } - - err = ap.ReleasePacket(ip) - if err != nil { - log.Fatal(err) - } - } -} -``` - +# wintun-go + +golang client for [wintun](https://git.zx2c4.com/wintun/about/) + + +##### Example: +```golang +package main + +import ( + "context" + "log" + "net" + "net/netip" + + "github.com/lysShub/wintun-go" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "gvisor.dev/gvisor/pkg/tcpip/header" // go get gvisor.dev/gvisor@go +) + +// curl google.com +func main() { + wintun.MustLoad(wintun.DLL) + + + ips, err := net.DefaultResolver.LookupIP(context.Background(), "ip4", "google.com") + if err != nil { + log.Fatal(err) + } + + ap, err := wintun.CreateAdapter("capture-google") + if err != nil { + log.Fatal(err) + } + defer ap.Close() + luid, err := ap.GetAdapterLuid() + if err != nil { + log.Fatal(err) + } + + var addr = netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 7, 3}), 24) + err = luid.SetIPAddresses([]netip.Prefix{addr}) + if err != nil { + log.Fatal(err) + } + + var routs []*winipcfg.RouteData + for _, e := range ips { + ip := netip.AddrFrom4([4]byte(e)) + dst := netip.PrefixFrom(ip, ip.BitLen()) + routs = append(routs, &winipcfg.RouteData{ + Destination: dst, + NextHop: addr.Addr(), + Metric: 5, + }) + } + err = luid.AddRoutes(routs) + if err != nil { + log.Fatal(err) + } + + for { + ip, err := ap.Receive(context.Background()) + if err != nil { + log.Fatal(err) + } + + if header.IPVersion(ip) == 4 { + iphdr := header.IPv4(ip) + if iphdr.TransportProtocol() == header.TCPProtocolNumber { + tcphdr := header.TCP(iphdr.Payload()) + + log.Printf("%s:%d --> %s:%d %s\n", + iphdr.SourceAddress().String(), tcphdr.SourcePort(), + iphdr.DestinationAddress().String(), tcphdr.DestinationPort(), + tcphdr.Flags(), + ) + } + } + + err = ap.ReleasePacket(ip) + if err != nil { + log.Fatal(err) + } + } +} +``` + diff --git a/wintun.go b/wintun.go index 4e4e08a..9bb1b13 100644 --- a/wintun.go +++ b/wintun.go @@ -4,149 +4,56 @@ package wintun import ( - "sync" "syscall" "unsafe" + "github.com/lysShub/divert-go/dll" "github.com/pkg/errors" "golang.org/x/sys/windows" ) -var global = wintun{} - func MustLoad[T string | Mem](p T) struct{} { err := Load(p) - if err != nil { + if err != nil && !errors.Is(err, ErrLoaded{}) { panic(err) } return struct{}{} } func Load[T string | Mem](p T) error { - global.Lock() - defer global.Unlock() - if global.dll != nil { + if wintun.Loaded() { return ErrLoaded{} } - var err error switch p := any(p).(type) { case string: - global.dll, err = loadFileDLL(p) - if err != nil { - return errors.WithStack(err) - } + dll.ResetLazyDll(wintun, p) case Mem: - global.dll, err = loadMemDLL(p) - if err != nil { - return errors.WithStack(err) - } + dll.ResetLazyDll(wintun, []byte(p)) default: - return windows.ERROR_INVALID_PARAMETER - } - - err = global.init() - return errors.WithStack(err) -} - -func Release() error { - global.Lock() - defer global.Unlock() - if global.dll == nil { - return nil - } - - err := global.dll.Release() - global.dll = nil - return errors.WithStack(err) -} - -type wintun struct { - sync.RWMutex - dll dll - - procCreateAdapter uintptr - procOpenAdapter uintptr - procCloseAdapter uintptr - procDeleteDriver uintptr - procGetAdapterLuid uintptr - procGetRunningDriverVersion uintptr - procSetLogger uintptr - procStartSession uintptr - procEndSession uintptr - procGetReadWaitEvent uintptr - procReceivePacket uintptr - procReleaseReceivePacket uintptr - procAllocateSendPacket uintptr - procSendPacket uintptr -} - -func (w *wintun) init() (err error) { - if global.procCreateAdapter, err = global.dll.FindProc("WintunCreateAdapter"); err != nil { - goto ret - } - if global.procOpenAdapter, err = global.dll.FindProc("WintunOpenAdapter"); err != nil { - goto ret - } - if global.procCloseAdapter, err = global.dll.FindProc("WintunCloseAdapter"); err != nil { - goto ret - } - if global.procDeleteDriver, err = global.dll.FindProc("WintunDeleteDriver"); err != nil { - goto ret - } - if global.procGetAdapterLuid, err = global.dll.FindProc("WintunGetAdapterLUID"); err != nil { - goto ret - } - if global.procGetRunningDriverVersion, err = global.dll.FindProc("WintunGetRunningDriverVersion"); err != nil { - goto ret - } - if global.procSetLogger, err = global.dll.FindProc("WintunSetLogger"); err != nil { - goto ret - } - if global.procStartSession, err = global.dll.FindProc("WintunStartSession"); err != nil { - goto ret + panic("") } - if global.procEndSession, err = global.dll.FindProc("WintunEndSession"); err != nil { - goto ret - } - if global.procGetReadWaitEvent, err = global.dll.FindProc("WintunGetReadWaitEvent"); err != nil { - goto ret - } - if global.procReceivePacket, err = global.dll.FindProc("WintunReceivePacket"); err != nil { - goto ret - } - if global.procReleaseReceivePacket, err = global.dll.FindProc("WintunReleaseReceivePacket"); err != nil { - goto ret - } - if global.procAllocateSendPacket, err = global.dll.FindProc("WintunAllocateSendPacket"); err != nil { - goto ret - } - if global.procSendPacket, err = global.dll.FindProc("WintunSendPacket"); err != nil { - goto ret - } - -ret: - if err != nil { - w.dll.Release() - w.dll = nil - } - return err + return nil } -func (w *wintun) calln(trap uintptr, args ...uintptr) (r1, r2 uintptr, err error) { - w.RLock() - defer w.RUnlock() - if w.dll == nil || trap == 0 { - return 0, 0, errors.WithStack(ErrNotLoad{}) - } - - var e syscall.Errno - r1, r2, e = syscall.SyscallN(trap, args...) - if e == windows.ERROR_SUCCESS { - return r1, r2, nil - } - return r1, r2, errors.WithStack(e) -} +var ( + wintun = dll.NewLazyDLL("wintun.dll") + + procCreateAdapter = wintun.NewProc("WintunCreateAdapter") + procOpenAdapter = wintun.NewProc("WintunOpenAdapter") + procCloseAdapter = wintun.NewProc("WintunCloseAdapter") + procDeleteDriver = wintun.NewProc("WintunDeleteDriver") + procGetAdapterLUID = wintun.NewProc("WintunGetAdapterLUID") + procGetRunningDriverVersion = wintun.NewProc("WintunGetRunningDriverVersion") + procSetLogger = wintun.NewProc("WintunSetLogger") + procStartSession = wintun.NewProc("WintunStartSession") + procEndSession = wintun.NewProc("WintunEndSession") + procGetReadWaitEvent = wintun.NewProc("WintunGetReadWaitEvent") + procReceivePacket = wintun.NewProc("WintunReceivePacket") + procReleaseReceivePacket = wintun.NewProc("WintunReleaseReceivePacket") + procAllocateSendPacket = wintun.NewProc("WintunAllocateSendPacket") + procSendPacket = wintun.NewProc("WintunSendPacket") +) func CreateAdapter(name string, opts ...Option) (*Adapter, error) { if len(name) == 0 { @@ -167,13 +74,13 @@ func CreateAdapter(name string, opts ...Option) (*Adapter, error) { return nil, errors.WithStack(err) } - r1, _, err := global.calln( - global.procCreateAdapter, + r1, _, err := syscall.SyscallN( + procCreateAdapter.Addr(), uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(o.guid)), ) - if err != nil { + if err != windows.ERROR_SUCCESS { return nil, errors.WithStack(err) } ap := &Adapter{handle: r1} @@ -191,8 +98,8 @@ func OpenAdapter(name string) (*Adapter, error) { return nil, errors.WithStack(err) } - r1, _, err := global.calln(global.procOpenAdapter, uintptr(unsafe.Pointer(name16))) - if err != nil { + r1, _, err := syscall.SyscallN(procOpenAdapter.Addr(), uintptr(unsafe.Pointer(name16))) + if err != windows.ERROR_SUCCESS { return nil, errors.WithStack(err) } ap := &Adapter{handle: r1} @@ -201,16 +108,16 @@ func OpenAdapter(name string) (*Adapter, error) { // todo: https://git.zx2c4.com/wintun-go/tree/wintun.go func DriverVersion() (version uint32, err error) { - r0, _, err := global.calln(global.procGetRunningDriverVersion) - if err != nil { + r0, _, err := syscall.SyscallN(procGetRunningDriverVersion.Addr()) + if err != windows.ERROR_SUCCESS { return 0, errors.WithStack(err) } return uint32(r0), nil } func DeleteDriver() error { - _, _, err := global.calln(global.procDeleteDriver) - if err != nil { + _, _, err := syscall.SyscallN(procDeleteDriver.Addr()) + if err != windows.ERROR_SUCCESS { return errors.WithStack(err) } return nil diff --git a/wintun_test.go b/wintun_test.go index d449f00..fdbe7a7 100644 --- a/wintun_test.go +++ b/wintun_test.go @@ -8,6 +8,7 @@ import ( "math/rand" "net/netip" "os" + "os/exec" "runtime" "testing" "time" @@ -19,65 +20,59 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -func randPort() int { - for { - port := uint16(rand.Uint32()) - if port > 2048 && port < 0xffff-0xff { - return int(port) - } - } +func Test_Gofmt(t *testing.T) { + cmd := exec.Command("cmd", "/C", "gofmt", "-l", "-w", `.`) + out, err := cmd.CombinedOutput() + + require.NoError(t, err) + require.Empty(t, string(out)) } -var dllPath = `.\embed\wintun_amd64.dll` +func Test_Load(t *testing.T) { + // go test -list ".*" ./... + // and go test with -run flag + t.Skip("require independent test") -func init() { - switch runtime.GOARCH { - case "amd64": - case "386": - dllPath = `.\embed\wintun_386.dll` - case "arm": - dllPath = `.\embed\wintun_arm.dll` - case "arm64": - dllPath = `.\embed\wintun_arm64.dll` - default: - panic("") - } -} + t.Run("mem:load-fail", func(t *testing.T) { + require.NoError(t, wintun.Load(make(wintun.Mem, 64))) -func buildICMP(t require.TestingT, src, dst []byte, typ header.ICMPv4Type, msg []byte) []byte { - require.Zero(t, len(msg)%4) + ap, err := wintun.CreateAdapter("testload0") + require.Error(t, err) + require.Nil(t, ap) + }) + t.Run("file:load-fail", func(t *testing.T) { + require.NoError(t, wintun.Load("./wintun.go")) - var p = make([]byte, 28+len(msg)) - iphdr := header.IPv4(p) - iphdr.Encode(&header.IPv4Fields{ - TOS: 0, - TotalLength: uint16(len(p)), - ID: uint16(rand.Uint32()), - Flags: 0, - FragmentOffset: 0, - TTL: 128, - Protocol: uint8(header.ICMPv4ProtocolNumber), - Checksum: 0, - SrcAddr: tcpip.AddrFromSlice(src), - DstAddr: tcpip.AddrFromSlice(dst), + ap, err := wintun.CreateAdapter("testload1") + require.Error(t, err) + require.Nil(t, ap) + }) + + t.Run("load-fail/load", func(t *testing.T) { + require.NoError(t, wintun.Load(make(wintun.Mem, 64))) + + ap, err := wintun.CreateAdapter("testload2") + require.Error(t, err) + require.Nil(t, ap) + + require.NoError(t, wintun.Load(dllPath)) + }) + + t.Run("load/load", func(t *testing.T) { + wintun.MustLoad(wintun.DLL) + + err := wintun.Load(wintun.DLL) + require.True(t, errors.Is(err, wintun.ErrLoaded{})) + require.True(t, + err.(interface{ Temporary() bool }).Temporary(), + ) }) - iphdr.SetChecksum(^checksum.Checksum(p[:iphdr.HeaderLength()], 0)) - require.True(t, iphdr.IsChecksumValid()) - icmphdr := header.ICMPv4(iphdr.Payload()) - icmphdr.SetType(typ) - icmphdr.SetIdent(0) - icmphdr.SetSequence(0) - icmphdr.SetChecksum(0) - copy(icmphdr.Payload(), msg) - icmphdr.SetChecksum(^checksum.Checksum(icmphdr, 0)) - return p } func Test_Example(t *testing.T) { // https://github.com/WireGuard/wintun/blob/master/example/example.c - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) // 10.6.7.7/24 var addr = netip.PrefixFrom( @@ -155,8 +150,7 @@ func Test_DriverVersion(t *testing.T) { t.Skip("can't get driver version") t.Run("mem", func(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) ver, err := wintun.DriverVersion() require.NoError(t, err) @@ -164,7 +158,6 @@ func Test_DriverVersion(t *testing.T) { }) t.Run("file", func(t *testing.T) { require.NoError(t, wintun.Load(dllPath)) - defer wintun.Release() ver, err := wintun.DriverVersion() require.NoError(t, err) @@ -174,8 +167,7 @@ func Test_DriverVersion(t *testing.T) { func Test_Logger(t *testing.T) { t.Run("mem", func(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() + wintun.MustLoad(wintun.DLL) buff := bytes.NewBuffer(nil) log := slog.New(slog.NewJSONHandler(buff, nil)) @@ -194,8 +186,8 @@ func Test_Logger(t *testing.T) { require.Contains(t, buff.String(), "Creating") }) t.Run("file", func(t *testing.T) { + t.Skip("require independent test") require.NoError(t, wintun.Load(dllPath)) - defer wintun.Release() buff := bytes.NewBuffer(nil) log := slog.New(slog.NewJSONHandler(buff, nil)) @@ -215,80 +207,66 @@ func Test_Logger(t *testing.T) { }) } -func Test_Load(t *testing.T) { - t.Run("mem:load-release/load-release", func(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - err := wintun.Release() - require.NoError(t, err) - - require.NoError(t, wintun.Load(wintun.DLL)) - err = wintun.Release() - require.NoError(t, err) - }) - t.Run("file:load-release/load-release", func(t *testing.T) { - require.NoError(t, wintun.Load(dllPath)) - err := wintun.Release() - require.NoError(t, err) - - require.NoError(t, wintun.Load(dllPath)) - err = wintun.Release() - require.NoError(t, err) - }) - - t.Run("mem:load-fail", func(t *testing.T) { - err := wintun.Load(make(wintun.Mem, 64)) - require.Error(t, err) - }) - t.Run("file:load-fail", func(t *testing.T) { - err := wintun.Load("./wintun.go") - require.Error(t, err) - }) - - t.Run("load-fail/load", func(t *testing.T) { - err := wintun.Load(make(wintun.Mem, 64)) - require.Error(t, err) - - require.NoError(t, wintun.Load(dllPath)) - defer wintun.Release() +func Test_Open(t *testing.T) { + t.Run("notload/open", func(t *testing.T) { + t.Skip("require independent test") + ap, err := wintun.OpenAdapter("xxx") + require.True(t, errors.Is(err, wintun.ErrNotLoad{})) + require.Nil(t, ap) }) - t.Run("load-fail/release", func(t *testing.T) { - err := wintun.Load(make(wintun.Mem, 64)) - require.Error(t, err) +} - require.NoError(t, wintun.Release()) - }) +func randPort() int { + for { + port := uint16(rand.Uint32()) + if port > 2048 && port < 0xffff-0xff { + return int(port) + } + } +} - t.Run("load/load", func(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - defer wintun.Release() +var dllPath = `.\embed\wintun_amd64.dll` - err := wintun.Load(wintun.DLL) - require.True(t, errors.Is(err, wintun.ErrLoaded{})) - require.True(t, - err.(interface{ Temporary() bool }).Temporary(), - ) - }) - t.Run("release/ralease", func(t *testing.T) { - err := wintun.Release() - require.NoError(t, err) +func init() { + switch runtime.GOARCH { + case "amd64": + case "386": + dllPath = `.\embed\wintun_386.dll` + case "arm": + dllPath = `.\embed\wintun_arm.dll` + case "arm64": + dllPath = `.\embed\wintun_arm64.dll` + default: + panic("") + } +} - err = wintun.Release() - require.NoError(t, err) - }) - t.Run("load/release/ralease", func(t *testing.T) { - require.NoError(t, wintun.Load(wintun.DLL)) - err := wintun.Release() - require.NoError(t, err) +func buildICMP(t require.TestingT, src, dst []byte, typ header.ICMPv4Type, msg []byte) []byte { + require.Zero(t, len(msg)%4) - err = wintun.Release() - require.NoError(t, err) + var p = make([]byte, 28+len(msg)) + iphdr := header.IPv4(p) + iphdr.Encode(&header.IPv4Fields{ + TOS: 0, + TotalLength: uint16(len(p)), + ID: uint16(rand.Uint32()), + Flags: 0, + FragmentOffset: 0, + TTL: 128, + Protocol: uint8(header.ICMPv4ProtocolNumber), + Checksum: 0, + SrcAddr: tcpip.AddrFromSlice(src), + DstAddr: tcpip.AddrFromSlice(dst), }) -} + iphdr.SetChecksum(^checksum.Checksum(p[:iphdr.HeaderLength()], 0)) + require.True(t, iphdr.IsChecksumValid()) -func Test_Open(t *testing.T) { - t.Run("notload/open", func(t *testing.T) { - ap, err := wintun.OpenAdapter("xxx") - require.True(t, errors.Is(err, wintun.ErrNotLoad{})) - require.Nil(t, ap) - }) + icmphdr := header.ICMPv4(iphdr.Payload()) + icmphdr.SetType(typ) + icmphdr.SetIdent(0) + icmphdr.SetSequence(0) + icmphdr.SetChecksum(0) + copy(icmphdr.Payload(), msg) + icmphdr.SetChecksum(^checksum.Checksum(icmphdr, 0)) + return p }