PyTorch interoperability
In general, a ROS 2 node may trivially wrap a PyTorch model for inference (e.g. in a callback). One notable exception is nodes spinning on multi-threaded executors -- the default when using synchros2 -- and performing inference on GPUs. Thread-local contexts and CPU <-> GPU synchronization mean that, for best performance out of the box, a model must always run on the same thread and never concurrently with others. ROS 2 and synchros2 afford a couple of idioms to deal with these constraints.
For illustrative purposes, the code snippets below use the sample models listed in the appendix.
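For instance, guarding inference in a callback with a plain lock (a hypothetical sketch below, where `segmentation` and `pub` stand in for a model wrapper and a publisher) prevents concurrent calls, but it does not keep the model on a single thread, which is why the idioms below are preferable:

import threading

inference_lock = threading.Lock()

def on_input_callback(message):
    # Serializes inference across executor threads, but each call may still land on a
    # different thread, which defeats thread-local (e.g. CUDA) state and hurts performance.
    with inference_lock:
        pub.publish(segmentation.perform(message))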
Single-threaded execution precludes the aforementioned issues. This makes it best suited for simple model wrappers:
# sample_node.py
from typing import Any

from sensor_msgs.msg import Image
from synchros2.node import Node
from synchros2.executors import foreground
import synchros2.process as ros_process
from rclpy.executors import SingleThreadedExecutor

from sample_models import MaskFormerROS


class MaskFormerROSNode(Node):

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__("sample_node", *args, **kwargs)
        self.segmentation = MaskFormerROS()
        self.pub = self.create_publisher(Image, "~/output/image", 1)
        self.sub = self.create_subscription(Image, "~/input/image", self.on_input_callback, 1)

    def on_input_callback(self, message: Image) -> None:
        self.pub.publish(self.segmentation.perform(message))


@ros_process.main(prebaked=False)
def main():
    with foreground(SingleThreadedExecutor()) as main.executor:
        main.spin(MaskFormerROSNode)


if __name__ == "__main__":
    main()
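To exercise the node, another process can publish frames to its input topic. A minimal sketch, assuming the node runs under its default name (the file name, topic, and synthetic frame below are illustrative, not part of the samples):

# publish_test_image.py
import time

import numpy as np
from cv_bridge import CvBridge
from sensor_msgs.msg import Image

import synchros2.process as ros_process


@ros_process.main()
def main():
    # "~/input/image" on a node named "sample_node" resolves to "/sample_node/input/image".
    pub = main.node.create_publisher(Image, "/sample_node/input/image", 1)
    frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # synthetic noise frame
    message = CvBridge().cv2_to_imgmsg(frame, "rgb8")
    while True:  # until Ctrl + C
        pub.publish(message)
        time.sleep(1.0)


if __name__ == "__main__":
    main()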
A single-threaded executor spinning in the background may be used for generic work dispatch. This can be handy in multi-threaded applications:
# sample_node.py
from rclpy.executors import SingleThreadedExecutor
from sensor_msgs.msg import Image
from synchros2.executors import background
from synchros2.futures import unwrap_future
import synchros2.process as ros_process

from sample_models import MaskFormerROS


@ros_process.main(autospin=False)
def main():
    segmentation = MaskFormerROS()
    with background(SingleThreadedExecutor()) as background_executor:
        pub = main.node.create_publisher(Image, "~/output/image", 1)

        def on_input_callback(message: Image) -> None:
            pub.publish(unwrap_future(background_executor.create_task(segmentation.perform, message)))

        main.node.create_subscription(Image, "~/input/image", on_input_callback, 1)
        main.spin()  # until Ctrl + C


if __name__ == "__main__":
    main()
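Since all inference funnels through the same background executor, other call sites may dispatch work to it as well. For instance, a hypothetical warm-up pass could be dispatched before spinning, inside the `with background(...)` block above (the synthetic frame is illustrative; other names reuse those from the snippet):

import numpy as np
from cv_bridge import CvBridge

# Hypothetical warm-up: run one inference on a synthetic frame before serving traffic,
# dispatched to the same background executor that serves the subscription callback.
warmup_frame = np.zeros((480, 640, 3), dtype=np.uint8)
warmup_message = CvBridge().cv2_to_imgmsg(warmup_frame, "rgb8")
unwrap_future(background_executor.create_task(segmentation.perform, warmup_message))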
Conversely, synchros2 abstractions and patterns may be leveraged to bring back simpler, linear code paths:
# sample_node.py
import contextlib

from sensor_msgs.msg import Image
import synchros2.process as ros_process
from synchros2.publisher import Publisher
from synchros2.subscription import Subscription

from sample_models import MaskFormerROS


@ros_process.main()
def main():
    segmentation = MaskFormerROS()
    publisher = Publisher(Image, "~/output/image")
    subscription = Subscription(Image, "~/input/image")
    with contextlib.closing(subscription.stream()) as stream:
        for image in stream:  # indefinitely until Ctrl + C
            publisher.publish(segmentation.perform(image))


if __name__ == "__main__":
    main()
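The stream behaves like a regular Python iterable, so standard iterator tooling applies. For instance, a hypothetical variation of the loop above that stops after a fixed number of frames:

import itertools

with contextlib.closing(subscription.stream()) as stream:
    for image in itertools.islice(stream, 100):  # process 100 frames, then stop
        publisher.publish(segmentation.perform(image))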
For more complex (or reusable) setups, where there is less control over execution paths, synchros2 executors support thread affinity settings for callback groups. That is, one or more callback groups may be configured to be served by specific thread pools of one or more workers (typically a single worker when dealing with inference and GPU workloads in general):
# sample_node.py
from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
from sensor_msgs.msg import Image
import synchros2.process as ros_process

from sample_models import MaskFormerROS


@ros_process.main(autospin=False)
def main():
    thread_affine_callback_group = MutuallyExclusiveCallbackGroup()
    thread_pool = main.executor.add_static_thread_pool(1)
    main.executor.bind(thread_affine_callback_group, thread_pool)
    segmentation = MaskFormerROS()
    pub = main.node.create_publisher(Image, "~/output/image", 1)

    def on_input_callback(message: Image) -> None:
        pub.publish(segmentation.perform(message))

    main.node.create_subscription(
        Image, "~/input/image", on_input_callback, 1,
        callback_group=thread_affine_callback_group,
    )
    main.spin()  # until Ctrl + C


if __name__ == "__main__":
    main()
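Multiple callback groups can likewise be bound to separate thread pools. A minimal sketch, assuming two model instances each pinned to its own dedicated worker thread (the two-camera setup and topic names are illustrative, not part of the samples):

# sample_node.py -- hypothetical two-model variant, one dedicated worker thread per model
import functools

from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
from sensor_msgs.msg import Image
import synchros2.process as ros_process

from sample_models import MaskFormerROS


def on_input_callback(model: MaskFormerROS, pub, message: Image) -> None:
    pub.publish(model.perform(message))


@ros_process.main(autospin=False)
def main():
    for name in ("left", "right"):
        # Each callback group is served by its own single-worker thread pool,
        # so each model instance always runs on the same thread.
        callback_group = MutuallyExclusiveCallbackGroup()
        main.executor.bind(callback_group, main.executor.add_static_thread_pool(1))
        model = MaskFormerROS()
        pub = main.node.create_publisher(Image, f"~/output/{name}/image", 1)
        main.node.create_subscription(
            Image, f"~/input/{name}/image",
            functools.partial(on_input_callback, model, pub), 1,
            callback_group=callback_group,
        )
    main.spin()  # until Ctrl + C


if __name__ == "__main__":
    main()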
Appendix
Below is a sample pretrained segmentation model, wrapped to interface with ROS messages:
# sample_models.py
import cv2
import numpy as np
import matplotlib.pyplot as plt

from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from transformers import (
    AutoImageProcessor,
    Mask2FormerForUniversalSegmentation,
)
import torch


def labels2rgb(labels: np.ndarray) -> np.ndarray:
    # Build a 256-entry lookup table mapping label values to colormap colors.
    label_range = np.arange(np.min(labels), np.max(labels) + 1)
    lut = np.zeros((256, 1, 3), dtype=np.uint8)
    lut[label_range, 0, :] = np.uint8(
        255 * plt.cm.tab20(label_range / label_range[-1])[:, :-1]
    )
    return cv2.LUT(cv2.merge((labels, labels, labels)), lut)


class MaskFormerROS:

    bridge = CvBridge()

    def __init__(self) -> None:
        self.image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-ade-semantic")
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-ade-semantic")

    def perform(self, message: Image) -> Image:
        image = self.bridge.imgmsg_to_cv2(message)
        inputs = self.image_processor(image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        pred_semantic_map = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.shape[:2]]
        )[0].numpy().astype(np.uint8)
        return self.bridge.cv2_to_imgmsg(labels2rgb(pred_semantic_map), "rgb8")
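The sample model above runs inference on CPU. For GPU inference, which is what the threading caveats above are mostly about, a variant might move the model and its inputs to a CUDA device. A minimal sketch, assuming a CUDA-capable device is available (the `MaskFormerROSCUDA` name is illustrative, not part of the samples):

# sample_models.py (continued) -- hypothetical GPU variant of the wrapper above
class MaskFormerROSCUDA(MaskFormerROS):

    def __init__(self) -> None:
        super().__init__()
        self.device = torch.device("cuda")
        self.model.to(self.device)

    def perform(self, message: Image) -> Image:
        image = self.bridge.imgmsg_to_cv2(message)
        # Move preprocessed tensors to the GPU before running the model.
        inputs = self.image_processor(image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Post-processing returns tensors on the model's device; bring results back to CPU.
        pred_semantic_map = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.shape[:2]]
        )[0].cpu().numpy().astype(np.uint8)
        return self.bridge.cv2_to_imgmsg(labels2rgb(pred_semantic_map), "rgb8")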