Skip to content

Commit

Permalink
Merge pull request #122 from BiffoBear/fix_socket_leaks
Browse files Browse the repository at this point in the history
Fix DHCP socket leak
  • Loading branch information
FoamyGuy authored Jul 21, 2023
2 parents 30aa274 + 91d1f31 commit b564228
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 110 deletions.
71 changes: 40 additions & 31 deletions adafruit_wiznet5k/adafruit_wiznet5k.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
_MR_RST = const(0x80) # Mode Register RST
# Socket mode register
_SNMR_CLOSE = const(0x00)
SNMR_TCP = const(0x21)
_SNMR_TCP = const(0x21)
SNMR_UDP = const(0x02)
_SNMR_IPRAW = const(0x03)
_SNMR_MACRAW = const(0x04)
Expand Down Expand Up @@ -492,7 +492,7 @@ def ifconfig(

# *** Public Socket Methods ***

def socket_available(self, socket_num: int, sock_type: int = SNMR_TCP) -> int:
def socket_available(self, socket_num: int, sock_type: int = _SNMR_TCP) -> int:
"""
Number of bytes available to be read from the socket.
Expand All @@ -514,7 +514,7 @@ def socket_available(self, socket_num: int, sock_type: int = SNMR_TCP) -> int:
self._sock_num_in_range(socket_num)

number_of_bytes = self._get_rx_rcv_size(socket_num)
if self.read_snsr(socket_num) == SNMR_UDP:
if self._read_snsr(socket_num) == SNMR_UDP:
number_of_bytes -= 8 # Subtract UDP header from packet size.
if number_of_bytes < 0:
raise ValueError("Negative number of bytes found on socket.")
Expand All @@ -533,14 +533,14 @@ def socket_status(self, socket_num: int) -> int:
:return int: The connection status.
"""
return self.read_snsr(socket_num)
return self._read_snsr(socket_num)

def socket_connect(
self,
socket_num: int,
dest: IpAddress4Raw,
port: int,
conn_mode: int = SNMR_TCP,
conn_mode: int = _SNMR_TCP,
) -> int:
"""
Open and verify a connection from a socket to a destination IPv4 address
Expand All @@ -567,11 +567,11 @@ def socket_connect(
# initialize a socket and set the mode
self.socket_open(socket_num, conn_mode=conn_mode)
# set socket destination IP and port
self.write_sndipr(socket_num, dest)
self.write_sndport(socket_num, port)
self.write_sncr(socket_num, _CMD_SOCK_CONNECT)
self._write_sndipr(socket_num, dest)
self._write_sndport(socket_num, port)
self._write_sncr(socket_num, _CMD_SOCK_CONNECT)

if conn_mode == SNMR_TCP:
if conn_mode == _SNMR_TCP:
# wait for tcp connection establishment
while self.socket_status(socket_num) != SNSR_SOCK_ESTABLISHED:
time.sleep(0.001)
Expand Down Expand Up @@ -638,7 +638,7 @@ def release_socket(self, socket_number):
WIZNET5K._sockets_reserved[socket_number - 1] = False

def socket_listen(
self, socket_num: int, port: int, conn_mode: int = SNMR_TCP
self, socket_num: int, port: int, conn_mode: int = _SNMR_TCP
) -> None:
"""
Listen on a socket's port.
Expand All @@ -665,15 +665,15 @@ def socket_listen(
self.socket_open(socket_num, conn_mode=conn_mode)
self.src_port = 0
# Send listen command
self.write_sncr(socket_num, _CMD_SOCK_LISTEN)
self._write_sncr(socket_num, _CMD_SOCK_LISTEN)
# Wait until ready
status = SNSR_SOCK_CLOSED
while status not in (
SNSR_SOCK_LISTEN,
SNSR_SOCK_ESTABLISHED,
_SNSR_SOCK_UDP,
):
status = self.read_snsr(socket_num)
status = self._read_snsr(socket_num)
if status == SNSR_SOCK_CLOSED:
raise RuntimeError("Listening socket closed.")

Expand Down Expand Up @@ -703,7 +703,7 @@ def socket_accept(self, socket_num: int) -> Tuple[int, Tuple[str, int]]:
)
return next_socknum, (dest_ip, dest_port)

def socket_open(self, socket_num: int, conn_mode: int = SNMR_TCP) -> None:
def socket_open(self, socket_num: int, conn_mode: int = _SNMR_TCP) -> None:
"""
Open an IP socket.
Expand All @@ -720,7 +720,7 @@ def socket_open(self, socket_num: int, conn_mode: int = SNMR_TCP) -> None:
self._sock_num_in_range(socket_num)
self._check_link_status()
debug_msg("*** Opening socket {}".format(socket_num), self._debug)
if self.read_snsr(socket_num) not in (
if self._read_snsr(socket_num) not in (
SNSR_SOCK_CLOSED,
SNSR_SOCK_TIME_WAIT,
SNSR_SOCK_FIN_WAIT,
Expand All @@ -732,22 +732,22 @@ def socket_open(self, socket_num: int, conn_mode: int = SNMR_TCP) -> None:
debug_msg("* Opening W5k Socket, protocol={}".format(conn_mode), self._debug)
time.sleep(0.00025)

self.write_snmr(socket_num, conn_mode)
self._write_snmr(socket_num, conn_mode)
self.write_snir(socket_num, 0xFF)

if self.src_port > 0:
# write to socket source port
self.write_sock_port(socket_num, self.src_port)
self._write_sock_port(socket_num, self.src_port)
else:
s_port = randint(49152, 65535)
while s_port in self._src_ports_in_use:
s_port = randint(49152, 65535)
self.write_sock_port(socket_num, s_port)
self._write_sock_port(socket_num, s_port)
self._src_ports_in_use[socket_num] = s_port

# open socket
self.write_sncr(socket_num, _CMD_SOCK_OPEN)
if self.read_snsr(socket_num) not in [_SNSR_SOCK_INIT, _SNSR_SOCK_UDP]:
self._write_sncr(socket_num, _CMD_SOCK_OPEN)
if self._read_snsr(socket_num) not in [_SNSR_SOCK_INIT, _SNSR_SOCK_UDP]:
raise RuntimeError("Could not open socket in TCP or UDP mode.")

def socket_close(self, socket_num: int) -> None:
Expand All @@ -760,14 +760,14 @@ def socket_close(self, socket_num: int) -> None:
"""
debug_msg("*** Closing socket {}".format(socket_num), self._debug)
self._sock_num_in_range(socket_num)
self.write_sncr(socket_num, _CMD_SOCK_CLOSE)
self._write_sncr(socket_num, _CMD_SOCK_CLOSE)
debug_msg(" Waiting for socket to close…", self._debug)
timeout = time.monotonic() + 5.0
while self.read_snsr(socket_num) != SNSR_SOCK_CLOSED:
while self._read_snsr(socket_num) != SNSR_SOCK_CLOSED:
if time.monotonic() > timeout:
raise RuntimeError(
"Wiznet5k failed to close socket, status = {}.".format(
self.read_snsr(socket_num)
self._read_snsr(socket_num)
)
)
time.sleep(0.0001)
Expand All @@ -783,7 +783,7 @@ def socket_disconnect(self, socket_num: int) -> None:
"""
debug_msg("*** Disconnecting socket {}".format(socket_num), self._debug)
self._sock_num_in_range(socket_num)
self.write_sncr(socket_num, _CMD_SOCK_DISCON)
self._write_sncr(socket_num, _CMD_SOCK_DISCON)

def socket_read(self, socket_num: int, length: int) -> Tuple[int, bytes]:
"""
Expand Down Expand Up @@ -819,7 +819,7 @@ def socket_read(self, socket_num: int, length: int) -> Tuple[int, bytes]:
# After reading the received data, update Sn_RX_RD register.
pointer = (pointer + bytes_on_socket) & 0xFFFF
self._write_snrx_rd(socket_num, pointer)
self.write_sncr(socket_num, _CMD_SOCK_RECV)
self._write_sncr(socket_num, _CMD_SOCK_RECV)
else:
# no data on socket
if self._read_snmr(socket_num) in (
Expand Down Expand Up @@ -906,7 +906,7 @@ def socket_write(
# update sn_tx_wr to the value + data size
pointer = (pointer + bytes_to_write) & 0xFFFF
self._write_sntx_wr(socket_num, pointer)
self.write_sncr(socket_num, _CMD_SOCK_SEND)
self._write_sncr(socket_num, _CMD_SOCK_SEND)

# check data was transferred correctly
while not self.read_snir(socket_num) & _SNIR_SEND_OK:
Expand Down Expand Up @@ -1057,6 +1057,11 @@ def _check_link_status(self):
if not self.link_status:
raise ConnectionError("The Ethernet connection is down.")

@staticmethod
def _read_socket_reservations() -> list[int]:
"""Return the list of reserved sockets."""
return WIZNET5K._sockets_reserved

def _read_mr(self) -> int:
"""Read from the Mode Register (MR)."""
return int.from_bytes(self._read(_REG_MR[self._chip_type], 0x00), "big")
Expand Down Expand Up @@ -1175,18 +1180,22 @@ def _read_sndipr(self, sock) -> bytes:
)
return bytes(data)

def write_sndipr(self, sock: int, ip_addr: bytes) -> None:
def _write_sndipr(self, sock: int, ip_addr: bytes) -> None:
"""Write to socket destination IP Address."""
for offset, value in enumerate(ip_addr):
self._write_socket_register(
sock, _REG_SNDIPR[self._chip_type] + offset, value
)

def write_sndport(self, sock: int, port: int) -> None:
def _read_sndport(self, sock: int) -> int:
"""Read socket destination port."""
return self._read_two_byte_sock_reg(sock, _REG_SNDPORT[self._chip_type])

def _write_sndport(self, sock: int, port: int) -> None:
"""Write to socket destination port."""
self._write_two_byte_sock_reg(sock, _REG_SNDPORT[self._chip_type], port)

def read_snsr(self, sock: int) -> int:
def _read_snsr(self, sock: int) -> int:
"""Read Socket n Status Register."""
return self._read_socket_register(sock, _REG_SNSR[self._chip_type])

Expand All @@ -1202,15 +1211,15 @@ def _read_snmr(self, sock: int) -> int:
"""Read the socket MR register."""
return self._read_socket_register(sock, _REG_SNMR)

def write_snmr(self, sock: int, protocol: int) -> None:
def _write_snmr(self, sock: int, protocol: int) -> None:
"""Write to Socket n Mode Register."""
self._write_socket_register(sock, _REG_SNMR, protocol)

def write_sock_port(self, sock: int, port: int) -> None:
def _write_sock_port(self, sock: int, port: int) -> None:
"""Write to the socket port number."""
self._write_two_byte_sock_reg(sock, _REG_SNPORT[self._chip_type], port)

def write_sncr(self, sock: int, data: int) -> None:
def _write_sncr(self, sock: int, data: int) -> None:
"""Write to socket command register."""
self._write_socket_register(sock, _REG_SNCR[self._chip_type], data)
# Wait for command to complete before continuing.
Expand Down
Loading

0 comments on commit b564228

Please sign in to comment.