Skip to content

Commit

Permalink
Add publisher extension
Browse files Browse the repository at this point in the history
  • Loading branch information
amessing-bdai committed Feb 5, 2025
1 parent 5bff590 commit 2ab3f4d
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
80 changes: 80 additions & 0 deletions synchros2/synchros2/publisher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2025 Boston Dynamics AI Institute Inc. All rights reserved.

from asyncio import Future
from typing import Any, Generic, Optional, Type, TypeVar

import rclpy.publisher
from rclpy.node import Node

import synchros2.scope as scope

# Represents a ros message type
_MessageT = TypeVar("_MessageT")

class Publisher(Generic[_MessageT]):
"""An extension of a publisher from ROS 2."""

def __init__(self, *args: Any, node: Optional[Node] = None, **kwargs: Any) -> None:
"""Initializes the Publisher
Args:
args: Positional arguments to pass to the `Node.create_publisher` function
node: Optional node for the underlying native subscription, defaults to
the current process node.
kwargs: Keyword arguments to pass to the `Node.create_publisher` function
"""
if node is None:
node = scope.ensure_node()
self._node = node
self._publisher = self._node.create_publisher(*args, **kwargs)

def subscription_matches(self, num_subscriptions: int) -> Future:
"""Gets a future to next publisher matching status update.
Note that in ROS 2 Humble and earlier distributions, this method relies on
polling the number of known subscriptions for the topic subscribed, as publisher
matching events are missing.
Args:
num_subscriptions: lower bound on the number of subscriptions to match.
Returns:
a future, done if the current number of subscriptions already matches
the specified lower bound.
"""
future_match = Future()
num_matched_publishers = self._node.count_subscribers(self._publisher.topic_name)
if num_matched_publishers < num_subscriptions:

def _poll_publisher_matches() -> None:
nonlocal future_match, num_subscriptions
if future_match.cancelled():
return
num_matched_subscriptions = self._node.count_subscribers(self._publisher.topic_name)
if num_subscriptions <= num_matched_subscriptions:
future_match.set_result(num_matched_subscriptions)

timer = self._node.create_timer(0.1, _poll_publisher_matches)
future_match.add_done_callback(lambda _: self._node.destroy_timer(timer))
else:
future_match.set_result(num_matched_publishers)
return future_match

@property
def matched_subscriptions(self) -> int:
"""Gets the number subscriptions matched and linked to.
Note that in ROS 2 Humble and earlier distributions, this property
relies on the number of known subscriptions for the topic subscribed
as subscription matching status info is missing.
"""
return self._node.count_subscribers(self._publisher.topic_name)

@property
def message_type(self) -> Type[_MessageT]:
"""Gets the type of the message subscribed."""
return self._publisher.msg_type

def publisher(self) -> rclpy.publisher.Publisher[_MessageT]:
"""Returns the internal ROS 2 publisher"""
return self._publisher
2 changes: 1 addition & 1 deletion synchros2/synchros2/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def subscriber(self) -> Subscriber:
def publisher_matches(self, num_publishers: int) -> Future:
"""Gets a future to next publisher matching status update.
Note that in ROS 2 Humble and ealier distributions, this method relies on
Note that in ROS 2 Humble and earlier distributions, this method relies on
polling the number of known publishers for the topic subscribed, as subscription
matching events are missing.
Expand Down
26 changes: 26 additions & 0 deletions synchros2/test/test_publisher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) 2025 Boston Dynamics AI Institute Inc. All rights reserved.

import std_msgs.msg
from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile
from synchros2.futures import wait_for_future
from synchros2.publisher import Publisher
from synchros2.scope import ROSAwareScope

DEFAULT_QOS_PROFILE = QoSProfile(
durability=DurabilityPolicy.TRANSIENT_LOCAL,
history=HistoryPolicy.KEEP_ALL,
depth=1,
)

def test_publisher_matching_subscriptions(ros: ROSAwareScope) -> None:
"""Asserts that checking for subscription matching on a publisher works as expected."""
assert ros.node is not None
sequence = Publisher(std_msgs.msg.Int8, "sequence", qos_profile=DEFAULT_QOS_PROFILE, node=ros.node)
assert sequence.matched_subscriptions == 0
future = sequence.subscription_matches(1)
assert not future.done()
future.cancel()

ros.node.create_subscription(std_msgs.msg.Int8, "sequence", qos_profile=DEFAULT_QOS_PROFILE, callback=lambda msg: None)
assert wait_for_future(sequence.subscription_matches(1), timeout_sec=5.0)
assert sequence.matched_subscriptions == 1

0 comments on commit 2ab3f4d

Please sign in to comment.