Skip to content

Commit 7ada09c

Browse files
committed
Refactor Visuomotor to VLA and build more of VLA test app.
1 parent 3771fc0 commit 7ada09c

File tree

8 files changed

+115
-216
lines changed

8 files changed

+115
-216
lines changed
Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package us.ihmc.rdx.ui.lerobot;
22

3-
import behavior_msgs.msg.dds.VisuomotorOperationMessage;
3+
import behavior_msgs.msg.dds.VLAOperationMessage;
44
import com.badlogic.gdx.graphics.Color;
55
import com.badlogic.gdx.graphics.g3d.Renderable;
66
import com.badlogic.gdx.utils.Array;
@@ -14,7 +14,7 @@
1414
import us.ihmc.communication.crdt.LatestTimestampModifiable;
1515
import us.ihmc.communication.ros2.ROS2ActorDesignation;
1616
import us.ihmc.communication.ros2.sync.ROS2PeerClockOffsetEstimator;
17-
import us.ihmc.lerobot.VisuomotorPolicyUpdateThread;
17+
import us.ihmc.lerobot.VLAUpdateThread;
1818
import us.ihmc.rdx.imgui.ImGuiAveragedFrequencyText;
1919
import us.ihmc.rdx.imgui.ImGuiTools;
2020
import us.ihmc.rdx.imgui.ImGuiUniqueLabelMap;
@@ -25,35 +25,33 @@
2525
import us.ihmc.ros2.ROS2Node;
2626
import us.ihmc.ros2.ROS2Publisher;
2727

28-
import static us.ihmc.lerobot.VisuomotorPolicyUpdateThread.OPERATOR_UI;
28+
import static us.ihmc.lerobot.VLAUpdateThread.UI;
2929

3030
/**
31-
* UI for remotely operating {@link VisuomotorPolicyUpdateThread}.
31+
* UI for remotely operating {@link VLAUpdateThread}.
3232
*/
33-
public class RDXVisuomotorOperation
33+
public class RDXVLAOperation
3434
{
3535
private final ImGuiUniqueLabelMap labels = new ImGuiUniqueLabelMap(getClass());
3636
private final Throttler commandThrottler = new Throttler().setFrequency(30.0);
3737
private final LatestTimestampModifiable latestTimestampModifiable;
3838
private final CRDTBidirectionalBoolean running;
3939
private final CRDTBidirectionalBoolean controlRobot;
40-
private double pythonStatusFrequency = 0.0;
41-
private long receivedActions = 0L;
4240
private String statusMessage = "Not yet connected to robot";
43-
private final TypedNotification<VisuomotorOperationMessage> statusSubscription;
44-
private final ROS2Publisher<VisuomotorOperationMessage> commandPublisher;
41+
private final TypedNotification<VLAOperationMessage> statusSubscription;
42+
private final ROS2Publisher<VLAOperationMessage> commandPublisher;
4543
private final ImGuiAveragedFrequencyText commsFrequencyText = new ImGuiAveragedFrequencyText();
4644
private final SideDependentList<RDXReferenceFrameGraphic> actionHandPoseGraphics = new SideDependentList<>();
4745
private final SideDependentList<RDXReferenceFrameGraphic> actionForearmPoseGraphics = new SideDependentList<>();
4846

49-
public RDXVisuomotorOperation(ROS2Node ros2Node, ROS2PeerClockOffsetEstimator peerClockEstimator)
47+
public RDXVLAOperation(ROS2Node ros2Node, ROS2PeerClockOffsetEstimator peerClockEstimator)
5048
{
5149
latestTimestampModifiable = new LatestTimestampModifiable(new CRDTInfo(ROS2ActorDesignation.OPERATOR, peerClockEstimator));
5250
running = new CRDTBidirectionalBoolean(latestTimestampModifiable, false);
5351
controlRobot = new CRDTBidirectionalBoolean(latestTimestampModifiable, false);
5452

55-
statusSubscription = ROS2Tools.createNotificationSubscription(ros2Node, OPERATOR_UI.getTopic(ROS2ActorDesignation.OPERATOR.getIncomingQualifier()));
56-
commandPublisher = ros2Node.createPublisher(OPERATOR_UI.getTopic(ROS2ActorDesignation.OPERATOR.getOutgoingQualifier()));
53+
statusSubscription = ROS2Tools.createNotificationSubscription(ros2Node, UI.getTopic(ROS2ActorDesignation.OPERATOR.getIncomingQualifier()));
54+
commandPublisher = ros2Node.createPublisher(UI.getTopic(ROS2ActorDesignation.OPERATOR.getOutgoingQualifier()));
5755
}
5856

5957
public void create(RDXBaseUI baseUI)
@@ -65,7 +63,7 @@ public void create(RDXBaseUI baseUI)
6563
}
6664

6765
baseUI.getPrimaryScene().addRenderableProvider(this::getRenderables);
68-
baseUI.getImGuiPanelManager().addPanel("Visuomotor Inference", this::renderImGuiWidgets);
66+
baseUI.getImGuiPanelManager().addPanel("VLA Operation", this::renderImGuiWidgets);
6967
}
7068

7169
public void update()
@@ -78,7 +76,7 @@ public void renderImGuiWidgets()
7876
if (statusSubscription.poll())
7977
{
8078
commsFrequencyText.ping();
81-
VisuomotorOperationMessage status = statusSubscription.read();
79+
VLAOperationMessage status = statusSubscription.read();
8280
latestTimestampModifiable.fromMessage(status.getLatestTimestampModifiable());
8381
running.fromMessage(status.getRunning());
8482
controlRobot.fromMessage(status.getControlRobot());
@@ -87,13 +85,11 @@ public void renderImGuiWidgets()
8785
actionHandPoseGraphics.get(side).setPoseInWorldFrame(status.getActionHandPoses()[side.ordinal()]);
8886
actionForearmPoseGraphics.get(side).setPoseInWorldFrame(status.getActionForearmPoses()[side.ordinal()]);
8987
}
90-
pythonStatusFrequency = status.getPythonStatusFrequency();
91-
statusMessage = status.getPythonStatusMessageAsString();
92-
receivedActions = status.getReceivedActions();
88+
statusMessage = status.getStatusMessageAsString();
9389
}
9490

95-
ImGui.text("Update Thread: %s Python: %3d Hz Actions: %d".formatted(commsFrequencyText.getText(), (int) pythonStatusFrequency, receivedActions));
96-
ImGui.text("Python status: " + statusMessage);
91+
ImGui.text("Update Thread: %s".formatted(commsFrequencyText.getText()));
92+
ImGui.text("Python status: %s".formatted(statusMessage));
9793
if (ImGui.checkbox(labels.get("Run inference"), running.getValue()))
9894
running.setValue(!running.getValue());
9995
ImGui.beginDisabled(running.getValue());
@@ -114,7 +110,7 @@ public void renderImGuiWidgets()
114110

115111
if (commandThrottler.run())
116112
{
117-
VisuomotorOperationMessage command = new VisuomotorOperationMessage();
113+
VLAOperationMessage command = new VLAOperationMessage();
118114
latestTimestampModifiable.toMessage(command.getLatestTimestampModifiable());
119115
command.setRunning(running.toMessage());
120116
command.setControlRobot(controlRobot.toMessage());
Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package us.ihmc.lerobot;
22

3-
import behavior_msgs.msg.dds.VisuomotorOperationMessage;
3+
import behavior_msgs.msg.dds.VLAOperationMessage;
44
import org.bytedeco.opencv.global.opencv_core;
55
import org.bytedeco.opencv.global.opencv_imgproc;
66
import org.bytedeco.opencv.opencv_core.Mat;
@@ -34,7 +34,6 @@
3434
import us.ihmc.robotics.robotSide.RobotSide;
3535
import us.ihmc.robotics.robotSide.SideDependentList;
3636
import us.ihmc.ros2.ROS2Node;
37-
import us.ihmc.ros2.ROS2NodeBuilder;
3837
import us.ihmc.ros2.ROS2Publisher;
3938
import us.ihmc.ros2.ROS2Topic;
4039
import us.ihmc.sensors.ImageSensor;
@@ -46,15 +45,13 @@
4645
import java.util.concurrent.CompletableFuture;
4746

4847
/**
49-
* Autonomy process thread for managing visuomotor inference and supporting remote UI.
50-
* Manages communication with the Python side, which is running the LeRobot code
51-
* with pytorch inference of the visuomotor policy. We use a ROS 2 API to interface with it.
48+
* Autonomy process thread for managing vision-language-action (VLA) inference and supporting remote UI.
49+
* Manages communication with the Python side, which is running the openpi.
5250
*/
53-
public class VisuomotorPolicyUpdateThread extends RepeatingTaskThread
51+
public class VLAUpdateThread extends RepeatingTaskThread
5452
{
55-
public static final ROS2IOTopicPair<VisuomotorOperationMessage> OPERATOR_UI
56-
= new ROS2IOTopicPair<>(new ROS2Topic<>().withPrefix("lerobot_ui").withTypeName(VisuomotorOperationMessage.class));
57-
53+
public static final ROS2IOTopicPair<VLAOperationMessage> UI = new ROS2IOTopicPair<>(new ROS2Topic<>().withPrefix("vla_ui")
54+
.withTypeName(VLAOperationMessage.class));
5855
private final ROS2SyncedRobotModel syncedRobot;
5956
private final ImageSensor zedSensor;
6057
private String status = "Not connected to openpi";
@@ -73,23 +70,23 @@ public class VisuomotorPolicyUpdateThread extends RepeatingTaskThread
7370
private final Throttler actionThrottler = new Throttler().setFrequency(5.0);
7471
private final int planSize = 5;
7572

76-
private final ROS2Node ros2Node = new ROS2NodeBuilder().build("visuomotor_update_thread");
7773
private final LatestTimestampModifiable latestTimestampModifiable;
7874
private long sequenceID = 0L;
7975
private final CRDTBidirectionalBoolean running;
8076
private final CRDTBidirectionalBoolean controlRobot;
81-
private final TypedNotification<VisuomotorOperationMessage> uiCommandSubscription;
82-
private final ROS2Publisher<VisuomotorOperationMessage> uiStatusPublisher;
77+
private final TypedNotification<VLAOperationMessage> uiCommandSubscription;
78+
private final ROS2Publisher<VLAOperationMessage> uiStatusPublisher;
8379

8480
private final ROS2Publisher<KinematicsStreamingToolboxInputMessage> kstInputPublisher;
8581
private final ROS2Publisher<ToolboxStateMessage> kstStatePublisher;
8682

87-
public VisuomotorPolicyUpdateThread(ROS2PeerClockOffsetEstimator clockOffsetEstimator,
88-
DRCRobotModel robotModel,
89-
ROS2SyncedRobotModel syncedRobot,
90-
ImageSensor zedSensor)
83+
public VLAUpdateThread(ROS2Node ros2Node,
84+
ROS2PeerClockOffsetEstimator clockOffsetEstimator,
85+
DRCRobotModel robotModel,
86+
ROS2SyncedRobotModel syncedRobot,
87+
ImageSensor zedSensor)
9188
{
92-
super(VisuomotorPolicyUpdateThread.class.getSimpleName());
89+
super(VLAUpdateThread.class.getSimpleName());
9390

9491
this.syncedRobot = syncedRobot;
9592
this.zedSensor = zedSensor;
@@ -104,8 +101,8 @@ public VisuomotorPolicyUpdateThread(ROS2PeerClockOffsetEstimator clockOffsetEsti
104101
running = new CRDTBidirectionalBoolean(latestTimestampModifiable, false);
105102
controlRobot = new CRDTBidirectionalBoolean(latestTimestampModifiable, false);
106103

107-
uiCommandSubscription = ROS2Tools.createNotificationSubscription(ros2Node, OPERATOR_UI.getTopic(ROS2ActorDesignation.ROBOT.getIncomingQualifier()));
108-
uiStatusPublisher = ros2Node.createPublisher(OPERATOR_UI.getTopic(ROS2ActorDesignation.ROBOT.getOutgoingQualifier()));
104+
uiCommandSubscription = ROS2Tools.createNotificationSubscription(ros2Node, UI.getTopic(ROS2ActorDesignation.ROBOT.getIncomingQualifier()));
105+
uiStatusPublisher = ros2Node.createPublisher(UI.getTopic(ROS2ActorDesignation.ROBOT.getOutgoingQualifier()));
109106

110107
kstInputPublisher = ros2Node.createPublisher(ToolboxAPIs.getIKStreamingInputTopic(robotModel.getSimpleRobotName()));
111108
kstStatePublisher = ros2Node.createPublisher(ToolboxAPIs.getIKStreamingStateTopic(robotModel.getSimpleRobotName()));
@@ -116,7 +113,7 @@ protected void runTask()
116113
{
117114
if (uiCommandSubscription.poll())
118115
{
119-
VisuomotorOperationMessage uiCommand = uiCommandSubscription.read();
116+
VLAOperationMessage uiCommand = uiCommandSubscription.read();
120117
latestTimestampModifiable.fromMessage(uiCommand.getLatestTimestampModifiable());
121118
boolean wasRunning = running.getValue();
122119
running.fromMessage(uiCommand.getRunning());
@@ -300,7 +297,7 @@ else if (openpiRequest.isDone())
300297
status = "Not running";
301298
}
302299

303-
VisuomotorOperationMessage uiStatus = new VisuomotorOperationMessage();
300+
VLAOperationMessage uiStatus = new VLAOperationMessage();
304301
latestTimestampModifiable.toMessage(uiStatus.getLatestTimestampModifiable());
305302
uiStatus.setSequenceId(sequenceID++);
306303
uiStatus.setRunning(running.toMessage());
@@ -310,15 +307,12 @@ else if (openpiRequest.isDone())
310307
uiStatus.getActionHandPoses()[side.ordinal()].set(actionHandPoses.get(side));
311308
uiStatus.getActionForearmPoses()[side.ordinal()].set(actionForearmPoses.get(side));
312309
}
313-
uiStatus.setPythonStatusFrequency(statusFrequency.getFrequencyDecaying());
314-
uiStatus.setPythonStatusMessage(status);
315-
uiStatus.setReceivedActions(numberOfActionsReceived);
310+
uiStatus.setStatusMessage("%-30s Actions: %d".formatted(status, numberOfActionsReceived));
316311
uiStatusPublisher.publish(uiStatus);
317312
}
318313

319314
public void destroy()
320315
{
321316
blockingKill();
322-
ros2Node.destroy();
323317
}
324318
}
Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef __behavior_msgs__msg__VisuomotorOperationMessage__idl__
2-
#define __behavior_msgs__msg__VisuomotorOperationMessage__idl__
1+
#ifndef __behavior_msgs__msg__VLAOperationMessage__idl__
2+
#define __behavior_msgs__msg__VLAOperationMessage__idl__
33

44
#include "geometry_msgs/msg/./Pose_.idl"
55
#include "ihmc_common_msgs/msg/./LatestModificationMessage_.idl"
@@ -11,10 +11,10 @@ module behavior_msgs
1111
{
1212

1313
/**
14-
* A message for remotely operating the visuomotor policies
14+
* A message for remotely operating the VLA policies
1515
*/
16-
@TypeCode(type="behavior_msgs::msg::dds_::VisuomotorOperationMessage_")
17-
struct VisuomotorOperationMessage
16+
@TypeCode(type="behavior_msgs::msg::dds_::VLAOperationMessage_")
17+
struct VLAOperationMessage
1818
{
1919
ihmc_common_msgs::msg::dds::LatestModificationMessage latest_timestamp_modifiable;
2020
/**
@@ -38,17 +38,9 @@ module behavior_msgs
3838
*/
3939
geometry_msgs::msg::dds::Pose action_forearm_poses[2];
4040
/**
41-
* Frequency of status from python side
41+
* Status message from robot thread to show to the operator
4242
*/
43-
double python_status_frequency;
44-
/**
45-
* Message received from python
46-
*/
47-
string python_status_message;
48-
/**
49-
* Number of output actions received from policy
50-
*/
51-
unsigned long received_actions;
43+
string status_message;
5244
};
5345
};
5446
};

0 commit comments

Comments
 (0)