Skip to content

Commit 708d2b2

Browse files
committed
Enable to publish rects
1 parent 97ee892 commit 708d2b2

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

node_script/node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Optional
33

44
import rospy
5-
from jsk_recognition_msgs.msg import LabelArray, VectorArray
5+
from jsk_recognition_msgs.msg import LabelArray, RectArray, VectorArray
66
from node_config import NodeConfig
77
from rospy import Publisher, Subscriber
88
from sensor_msgs.msg import Image
@@ -24,6 +24,7 @@ class DeticRosNode:
2424
pub_segimg: Optional[Publisher]
2525
pub_labels: Optional[Publisher]
2626
pub_score: Optional[Publisher]
27+
pub_rects: Optional[Publisher]
2728

2829
# otherwise, the following publisher will be used
2930
pub_info: Optional[Publisher]
@@ -45,6 +46,7 @@ def __init__(self, node_config: Optional[NodeConfig] = None):
4546
self.pub_segimg = rospy.Publisher('~segmentation', Image, queue_size=1)
4647
self.pub_labels = rospy.Publisher('~detected_classes', LabelArray, queue_size=1)
4748
self.pub_score = rospy.Publisher('~score', VectorArray, queue_size=1)
49+
self.pub_rects = rospy.Publisher('~rects', RectArray, queue_size=1)
4850
else:
4951
self.pub_info = rospy.Publisher('~segmentation_info', SegmentationInfo,
5052
queue_size=1)
@@ -77,9 +79,11 @@ def callback_image(self, msg: Image):
7779
seg_img = raw_result.get_ros_segmentaion_image()
7880
labels = raw_result.get_label_array()
7981
scores = raw_result.get_score_array()
82+
rects = raw_result.get_rect_array()
8083
self.pub_segimg.publish(seg_img)
8184
self.pub_labels.publish(labels)
8285
self.pub_score.publish(scores)
86+
self.pub_rects.publish(rects)
8387
else:
8488
assert self.pub_info is not None
8589
seg_info = raw_result.get_segmentation_info()

node_script/wrapper.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from cv_bridge import CvBridge
1111
from detectron2.utils.visualizer import VisImage
1212
from detic.predictor import VisualizationDemo
13-
from jsk_recognition_msgs.msg import Label, LabelArray, VectorArray
13+
from jsk_recognition_msgs.msg import Label, LabelArray, Rect, RectArray, VectorArray
1414
from node_config import NodeConfig
1515
from sensor_msgs.msg import Image
1616
from std_msgs.msg import Header
@@ -28,6 +28,7 @@ class InferenceRawResult:
2828
visualization: Optional[VisImage]
2929
header: Header
3030
detected_class_names: List[str]
31+
boxes: List[List[float]]
3132

3233
def get_ros_segmentaion_image(self) -> Image:
3334
seg_img = _cv_bridge.cv2_to_imgmsg(self.segmentation_raw_image, encoding="32SC1")
@@ -68,6 +69,14 @@ def get_segmentation_info(self) -> SegmentationInfo:
6869
header=self.header)
6970
return seg_info
7071

72+
def get_rect_array(self) -> RectArray:
73+
rects = [Rect(x=int(box[0]),
74+
y=int(box[1]),
75+
width=int(box[2] - box[0]),
76+
height=int(box[3] - box[1])) for box in self.boxes]
77+
rec_arr = RectArray(header=self.header, rects=rects)
78+
return rec_arr
79+
7180

7281
class DeticWrapper:
7382
predictor: VisualizationDemo
@@ -122,12 +131,14 @@ def infer(self, msg: Image) -> InferenceRawResult:
122131
pred_masks = list(instances.pred_masks)
123132
scores = instances.scores.tolist()
124133
class_indices = instances.pred_classes.tolist()
134+
boxes = list(instances.pred_boxes)
125135

126136
if len(scores) > 0 and self.node_config.output_highest:
127137
best_index = np.argmax(scores)
128138
pred_masks = [pred_masks[best_index]]
129139
scores = [scores[best_index]]
130140
class_indices = [class_indices[best_index]]
141+
boxes = [boxes[best_index]]
131142

132143
if self.node_config.verbose:
133144
rospy.loginfo("{} with highest score {}".format(self.class_names[class_indices[0]], scores[best_index]))
@@ -150,5 +161,6 @@ def infer(self, msg: Image) -> InferenceRawResult:
150161
scores,
151162
visualized_output,
152163
msg.header,
153-
detected_classes_names)
164+
detected_classes_names,
165+
boxes)
154166
return result

0 commit comments

Comments
 (0)