Skip to content

PyTorch interoperability

mhidalgo-bdai edited this page May 29, 2025 · 1 revision

In general, a ROS 2 node may trivially wrap a PyTorch model for inference (e.g. in a callback). One notable exception is nodes that spin on multi-threaded executors -- as is the default when using synchros2 -- and perform inference on GPUs. Thread-local contexts and CPU <> GPU synchronization mean that, for best performance out of the box, a model must always run on the same thread and never concurrently with others. ROS 2 and synchros2 afford a couple of idioms to deal with these constraints.

Idioms

For illustrative purposes, the code snippets below use the sample models listed in the appendix.

Foreground inference only

Single-threaded execution precludes the aforementioned issues. This makes it best suited for simple model wrappers:

# sample_node.py

from typing import Any
from sensor_msgs.msg import Image

from synchros2.node import Node
from synchros2.executors import foreground
import synchros2.process as ros_process 

from rclpy.executors import SingleThreadedExecutor

from sample_models import MaskFormerROS

class MaskFormerROSNode(Node):
    """Node that segments each incoming image and republishes the result.

    Inference runs directly inside the subscription callback, so this node is
    meant to be spun by a single-threaded executor.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__("sample_node", *args, **kwargs)
        # The model must exist before the subscription can deliver messages.
        self.segmentation = MaskFormerROS()
        self.pub = self.create_publisher(Image, "~/output/image", 1)
        self.sub = self.create_subscription(Image, "~/input/image", self.on_input_callback, 1)

    def on_input_callback(self, message: Image) -> None:
        """Segment ``message`` and publish the resulting image."""
        segmented = self.segmentation.perform(message)
        self.pub.publish(segmented)

@ros_process.main(prebaked=False)
def main():
    """Spin a ``MaskFormerROSNode`` on a single-threaded executor in the foreground."""
    executor = SingleThreadedExecutor()
    with foreground(executor) as main.executor:
        main.spin(MaskFormerROSNode)


if __name__ == "__main__":
    main()

Background inference, foreground threads

A single-threaded executor spinning in the background may be used for generic work dispatch, e.g. to run every inference call on one dedicated thread. This can be handy in multi-threaded applications:

# sample_node.py

from rclpy.executors import SingleThreadedExecutor

from sensor_msgs.msg import Image
from synchros2.executors import background
from synchros2.futures import unwrap_future
import synchros2.process as ros_process 

from sample_models import MaskFormerROS

@ros_process.main(autospin=False)
def main():
    """Serve ROS callbacks via ``main.spin()`` but run inference on a background executor.

    Every ``segmentation.perform`` call is dispatched to the same
    single-threaded background executor, so the model always runs on one thread.
    """
    segmentation = MaskFormerROS()
    with background(SingleThreadedExecutor()) as background_executor:
        pub = main.node.create_publisher(Image, "~/output/image", 1)

        def on_input_callback(message: Image) -> None:
            # Hand inference off to the background executor; unwrap_future
            # yields the task's result, which is then published.
            inference = background_executor.create_task(segmentation.perform, message)
            pub.publish(unwrap_future(inference))

        main.node.create_subscription(Image, "~/input/image", on_input_callback, 1)
        main.spin()  # until Ctrl + C


if __name__ == "__main__":
    main()

Background threads, foreground inference

Conversely, synchros2 abstractions and patterns may be leveraged to bring back simpler, linear code paths, with inference running in the foreground while ROS 2 work is served by background threads:

# sample_node.py

import contextlib

from sensor_msgs.msg import Image

import synchros2.process as ros_process
from synchros2.publisher import Publisher
from synchros2.subscription import Subscription

from sample_models import MaskFormerROS

@ros_process.main()
def main():
    """Segment images as they stream in, running inference in this function's thread."""
    segmentation = MaskFormerROS()
    publisher = Publisher(Image, "~/output/image")
    subscription = Subscription(Image, "~/input/image")
    # closing() guarantees the stream is closed even if the loop is interrupted.
    with contextlib.closing(subscription.stream()) as stream:
        for image in stream:  # indefinitely until Ctrl + C
            result = segmentation.perform(image)
            publisher.publish(result)


if __name__ == "__main__":
    main()

Callback groups with thread affinity

For more complex (or reusable) setups, when there's less control over execution paths, synchros2 executors support thread affinity settings for callback groups. That is, one or more callback groups may be configured to be served by specific thread pools of one or more workers (typically one when dealing with inference and GPU workloads in general):

# sample_node.py

from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
from sensor_msgs.msg import Image

import synchros2.process as ros_process

from sample_models import MaskFormerROS

@ros_process.main(autospin=False)
def main():
    """Pin inference callbacks to a dedicated single-worker thread pool.

    The subscription's callback group is bound to a static pool with one
    worker, so every inference call is served by that same thread.
    """
    inference_group = MutuallyExclusiveCallbackGroup()
    inference_pool = main.executor.add_static_thread_pool(1)
    main.executor.bind(inference_group, inference_pool)

    segmentation = MaskFormerROS()
    pub = main.node.create_publisher(Image, "~/output/image", 1)

    def on_input_callback(message: Image) -> None:
        segmented = segmentation.perform(message)
        pub.publish(segmented)

    main.node.create_subscription(
        Image,
        "~/input/image",
        on_input_callback,
        1,
        callback_group=inference_group,
    )

    main.spin()  # until Ctrl + C


if __name__ == "__main__":
    main()

Appendix

Below is a sample pretrained segmentation model, wrapped to interface with ROS messages:

# sample_models.py

import cv2
import numpy as np
import matplotlib.pyplot as plt

from cv_bridge import CvBridge
from sensor_msgs.msg import Image

from transformers import (
    AutoImageProcessor, 
    Mask2FormerForUniversalSegmentation,
)
import torch


def labels2rgb(labels: np.ndarray) -> np.ndarray:
    """Colorize a single-channel label map using matplotlib's ``tab20`` colormap.

    Every label value ``l`` in ``[min(labels), max(labels)]`` is painted with
    ``tab20(l / max(labels))``; LUT entries outside that range stay black.

    Args:
        labels: single-channel label image. ``cv2.LUT`` requires 8-bit input,
            so values must lie in ``[0, 255]`` (assumed ``uint8``).

    Returns:
        A 3-channel ``uint8`` color image of the same height and width.
    """
    lowest = int(np.min(labels))
    highest = int(np.max(labels))
    # Inclusive range: the maximum label must receive a color as well.
    label_range = np.arange(lowest, highest + 1)
    # Avoid dividing by zero when the map contains only label 0.
    scale = max(highest, 1)
    colors = plt.cm.tab20(label_range / scale)[:, :-1]  # drop alpha channel
    lut = np.zeros((256, 1, 3), dtype=np.uint8)
    # Index the LUT by the label value itself (not by its offset from the
    # minimum), and scale by 255 so a 1.0 color component does not overflow uint8.
    lut[label_range, 0, :] = np.round(255 * colors).astype(np.uint8)
    return cv2.LUT(cv2.merge((labels, labels, labels)), lut)


class MaskFormerROS:
    """Mask2Former semantic segmentation wrapped to consume and produce ROS images.

    Loads the ``facebook/mask2former-swin-small-ade-semantic`` checkpoint via
    Hugging Face ``transformers``.
    """

    # Shared by all instances for ROS <-> OpenCV image conversions.
    bridge = CvBridge()

    def __init__(self, device: "str | torch.device" = "cpu") -> None:
        """Load the pretrained image processor and model.

        Args:
            device: torch device to run inference on (e.g. ``"cpu"`` or
                ``"cuda:0"``). Defaults to ``"cpu"``, matching prior behavior.
        """
        checkpoint = "facebook/mask2former-swin-small-ade-semantic"
        self.device = torch.device(device)
        self.image_processor = AutoImageProcessor.from_pretrained(checkpoint)
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint).to(self.device)

    def perform(self, message: Image) -> Image:
        """Segment ``message`` and return a colorized semantic label image.

        Args:
            message: input ROS image.

        Returns:
            An ``rgb8`` ROS image of per-pixel class colors, carrying the
            input message's header.
        """
        # Normalize to RGB regardless of the incoming encoding (e.g. bgr8 or
        # mono8): the image processor assumes RGB channel order and a
        # 3-channel input.
        image = self.bridge.imgmsg_to_cv2(message, desired_encoding="rgb8")

        inputs = self.image_processor(image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # target_sizes expects (height, width); image.shape is (H, W, 3).
        # Move to CPU before converting to numpy, as outputs may live on a GPU.
        pred_semantic_map = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.shape[:2]]
        )[0].cpu().numpy().astype(np.uint8)

        result = self.bridge.cv2_to_imgmsg(labels2rgb(pred_semantic_map), "rgb8")
        # Keep timestamp and frame id so consumers can match output to input.
        result.header = message.header
        return result
Clone this wiki locally