Skip to content

Commit 022588b

Browse files
committed
Add conversion between full and active index in DofMask.
1 parent 7047d0e commit 022588b

File tree

3 files changed

+140
-24
lines changed

3 files changed

+140
-24
lines changed

planning/dof_mask.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "drake/planning/dof_mask.h"
22

3+
#include "planning/dof_mask.h"
34
#include <fmt/format.h>
45
#include <fmt/ranges.h>
56

@@ -10,12 +11,36 @@ using multibody::Joint;
1011
using multibody::JointIndex;
1112
using multibody::ModelInstanceIndex;
1213
using multibody::MultibodyPlant;
14+
namespace {
15+
void SetFullAndActiveIndices(
16+
const std::vector<bool>& data, std::vector<int>* full_indices,
17+
std::unordered_map<int, int>* full_to_active_index) {
18+
full_indices->clear();
19+
full_to_active_index->clear();
20+
int count = 0;
21+
for (int i = 0; i < std::ssize(data); ++i) {
22+
if (data[i]) {
23+
full_indices->push_back(i);
24+
full_to_active_index->emplace(i, count);
25+
++count;
26+
}
27+
}
28+
}
29+
} // namespace
1330

1431
DofMask::DofMask() = default;
1532

1633
DofMask::DofMask(int size, bool value)
1734
: data_(size >= 0 ? size : 0, value), count_(value ? size : 0) {
1835
DRAKE_THROW_UNLESS(size >= 0);
36+
if (value) {
37+
full_indices_.reserve(size);
38+
full_to_active_index_.reserve(size);
39+
for (int i = 0; i < size; ++i) {
40+
full_indices_.push_back(i);
41+
full_to_active_index_.emplace(i, i);
42+
}
43+
}
1944
}
2045

2146
DofMask::DofMask(std::initializer_list<bool> values)
@@ -27,6 +52,7 @@ DofMask::DofMask(std::vector<bool> values) : data_(std::move(values)) {
2752
++count_;
2853
}
2954
}
55+
SetFullAndActiveIndices(data_, &full_indices_, &full_to_active_index_);
3056
}
3157

3258
DofMask DofMask::MakeFromModel(const MultibodyPlant<double>& plant,
@@ -74,6 +100,8 @@ DofMask DofMask::Complement() const {
74100
DofMask result(*this);
75101
result.data_.flip();
76102
result.count_ = this->size() - this->count();
103+
SetFullAndActiveIndices(result.data_, &result.full_indices_,
104+
&result.full_to_active_index_);
77105
return result;
78106
}
79107

@@ -88,6 +116,8 @@ DofMask DofMask::Union(const DofMask& other) const {
88116
++result.count_;
89117
}
90118
}
119+
SetFullAndActiveIndices(result.data_, &result.full_indices_,
120+
&result.full_to_active_index_);
91121
return result;
92122
}
93123

@@ -102,6 +132,8 @@ DofMask DofMask::Intersect(const DofMask& other) const {
102132
++result.count_;
103133
}
104134
}
135+
SetFullAndActiveIndices(result.data_, &result.full_indices_,
136+
&result.full_to_active_index_);
105137
return result;
106138
}
107139

@@ -116,6 +148,8 @@ DofMask DofMask::Subtract(const DofMask& other) const {
116148
++result.count_;
117149
}
118150
}
151+
SetFullAndActiveIndices(result.data_, &result.full_indices_,
152+
&result.full_to_active_index_);
119153
return result;
120154
}
121155

@@ -203,5 +237,18 @@ bool DofMask::operator==(const DofMask& o) const {
203237
return result;
204238
}
205239

240+
std::optional<int> DofMask::full_to_active_index(int i) const {
241+
DRAKE_THROW_UNLESS(i >= 0 && i < this->size());
242+
if (data_[i]) {
243+
return full_to_active_index_.at(i);
244+
}
245+
return std::nullopt;
246+
}
247+
248+
int DofMask::active_to_full_index(int i) const {
249+
DRAKE_THROW_UNLESS(i >= 0 && i < this->count());
250+
return full_indices_[i];
251+
}
252+
206253
} // namespace planning
207254
} // namespace drake

planning/dof_mask.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <initializer_list>
44
#include <string>
5+
#include <unordered_map>
56
#include <utility>
67
#include <vector>
78

@@ -234,6 +235,18 @@ class DofMask {
234235
@pre `output.size() == size()`. */
235236
void SetInArray(const Eigen::Ref<const Eigen::VectorXd>& vec,
236237
drake::EigenPtr<Eigen::VectorXd> output) const;
238+
239+
/** Returns the index in the full vector. Namely q_active[i] =
240+
q_full[dof_mask.active_to_full_index(i)].
241+
@pre i >= 0 and i < this->count()
242+
*/
243+
[[nodiscard]] int active_to_full_index(int i) const;
244+
245+
/** Returns the index in the active vector. Namely
246+
q_active[dof_mask.full_to_active_index(i)] = q_full[i]. If (*this)[i] ==
247+
false (namely this DoF is inactive), then return a nullopt.
248+
@pre i >= 0, i < this->size(). */
249+
[[nodiscard]] std::optional<int> full_to_active_index(int i) const;
237250
//@}
238251

239252
private:
@@ -246,6 +259,12 @@ class DofMask {
246259
// These member fields are almost "const" -- we have no member functions that
247260
// mutate them, other than the two default assignment operators.
248261
std::vector<bool> data_;
262+
// Maps the active index to the full index, namely q_active[i] =
263+
// q_full[full_indices_[i]].
264+
std::vector<int> full_indices_;
265+
// Maps the full index to the active index, namely
266+
// q_active[full_to_active_index_[i]] = q_full[i].
267+
std::unordered_map<int, int> full_to_active_index_;
249268
reset_after_move<int> count_{0};
250269
};
251270

planning/test/dof_mask_test.cc

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,55 @@ using multibody::ModelInstanceIndex;
2626

2727
/* Exercise all of the constructors and evaluate the size-like APIs. */
2828
GTEST_TEST(DofMaskTest, ConstructorsAndSize) {
29-
const DofMask dut1;
30-
EXPECT_EQ(dut1.count(), 0);
31-
EXPECT_EQ(dut1.size(), 0);
32-
33-
const DofMask dut2(13, true);
34-
EXPECT_EQ(dut2.count(), 13);
35-
EXPECT_EQ(dut2.size(), 13);
36-
37-
const DofMask dut3(14, false);
38-
EXPECT_EQ(dut3.count(), 0);
39-
EXPECT_EQ(dut3.size(), 14);
40-
41-
const DofMask dut4({true, false, true, true, false});
42-
EXPECT_EQ(dut4.count(), 3);
43-
EXPECT_EQ(dut4.size(), 5);
44-
45-
std::vector<bool> bits{true, true, false, true, true, false};
46-
const DofMask dut5(bits);
47-
EXPECT_EQ(dut5.count(), 4);
48-
EXPECT_EQ(dut5.size(), 6);
29+
{
30+
const DofMask dut1;
31+
EXPECT_EQ(dut1.count(), 0);
32+
EXPECT_EQ(dut1.size(), 0);
33+
}
34+
{
35+
const DofMask dut2(13, true);
36+
EXPECT_EQ(dut2.count(), 13);
37+
EXPECT_EQ(dut2.size(), 13);
38+
for (int i = 0; i < dut2.size(); ++i) {
39+
EXPECT_EQ(dut2.full_to_active_index(i), i);
40+
EXPECT_EQ(dut2.active_to_full_index(i), i);
41+
}
42+
}
43+
{
44+
const DofMask dut3(14, false);
45+
EXPECT_EQ(dut3.count(), 0);
46+
EXPECT_EQ(dut3.size(), 14);
47+
for (int i = 0; i < dut3.size(); ++i) {
48+
EXPECT_FALSE(dut3.full_to_active_index(i).has_value());
49+
EXPECT_THROW(drake::unused(dut3.active_to_full_index(i)), std::exception);
50+
}
51+
}
52+
{
53+
const DofMask dut4({true, false, true, true, false});
54+
EXPECT_EQ(dut4.count(), 3);
55+
EXPECT_EQ(dut4.size(), 5);
56+
57+
EXPECT_EQ(dut4.full_to_active_index(0).value(), 0);
58+
EXPECT_FALSE(dut4.full_to_active_index(1).has_value());
59+
EXPECT_EQ(dut4.full_to_active_index(2), 1);
60+
EXPECT_EQ(dut4.full_to_active_index(3), 2);
61+
EXPECT_FALSE(dut4.full_to_active_index(4).has_value());
62+
EXPECT_EQ(dut4.active_to_full_index(0), 0);
63+
EXPECT_EQ(dut4.active_to_full_index(1), 2);
64+
EXPECT_EQ(dut4.active_to_full_index(2), 3);
65+
}
66+
{
67+
std::vector<bool> bits{true, true, false, true, true, false};
68+
const DofMask dut5(bits);
69+
EXPECT_EQ(dut5.count(), 4);
70+
EXPECT_EQ(dut5.size(), 6);
71+
EXPECT_EQ(dut5.full_to_active_index(0), 0);
72+
EXPECT_EQ(dut5.full_to_active_index(1), 1);
73+
EXPECT_FALSE(dut5.full_to_active_index(2).has_value());
74+
EXPECT_EQ(dut5.full_to_active_index(3), 2);
75+
EXPECT_EQ(dut5.full_to_active_index(4), 3);
76+
EXPECT_FALSE(dut5.full_to_active_index(5).has_value());
77+
}
4978
}
5079

5180
// In addition to testing move/copy semantics, this also provides testing for
@@ -183,35 +212,56 @@ GTEST_TEST(DofMaskTest, ToString) {
183212
EXPECT_EQ(dut.to_string(), fmt::to_string(int_bits));
184213
}
185214

215+
namespace {
216+
void ExpectFullAndActiveIndicesEqual(const DofMask& mask1,
217+
const DofMask& mask2) {
218+
ASSERT_EQ(mask1.count(), mask2.count());
219+
ASSERT_EQ(mask1.size(), mask2.size());
220+
for (int i = 0; i < mask1.size(); ++i) {
221+
EXPECT_EQ(mask1.full_to_active_index(i), mask2.full_to_active_index(i));
222+
}
223+
for (int i = 0; i < mask1.count(); ++i) {
224+
EXPECT_EQ(mask1.active_to_full_index(i), mask2.active_to_full_index(i));
225+
}
226+
}
227+
} // namespace
186228
GTEST_TEST(DofMaskTest, Complement) {
187229
const DofMask d1({true, false, false});
188230
const DofMask expected({false, true, true});
189231

190-
EXPECT_EQ(d1.Complement(), expected);
232+
const DofMask d1_complement = d1.Complement();
233+
EXPECT_EQ(d1_complement, expected);
234+
ExpectFullAndActiveIndicesEqual(d1_complement, expected);
191235
}
192236

193237
GTEST_TEST(DofMaskTest, Union) {
194238
const DofMask d1({true, false, false});
195239
const DofMask d2({false, false, true});
196240
const DofMask expected({true, false, true});
197241

198-
EXPECT_EQ(d1.Union(d2), expected);
242+
const DofMask d1_union_d2 = d1.Union(d2);
243+
EXPECT_EQ(d1_union_d2, expected);
244+
ExpectFullAndActiveIndicesEqual(d1_union_d2, expected);
199245
}
200246

201247
GTEST_TEST(DofMaskTest, Intersect) {
202248
const DofMask d1({true, true, false});
203249
const DofMask d2({false, true, true});
204250
const DofMask expected({false, true, false});
205251

206-
EXPECT_EQ(d1.Intersect(d2), expected);
252+
const DofMask d1_intersect_d2 = d1.Intersect(d2);
253+
EXPECT_EQ(d1_intersect_d2, expected);
254+
ExpectFullAndActiveIndicesEqual(d1_intersect_d2, expected);
207255
}
208256

209257
GTEST_TEST(DofMaskTest, Subtract) {
210258
const DofMask d1({true, true, false});
211259
const DofMask d2({false, true, true});
212260
const DofMask expected({true, false, false});
213261

214-
EXPECT_EQ(d1.Subtract(d2), expected);
262+
const DofMask d1_subtract_d2 = d1.Subtract(d2);
263+
EXPECT_EQ(d1_subtract_d2, expected);
264+
ExpectFullAndActiveIndicesEqual(d1_subtract_d2, expected);
215265
}
216266

217267
GTEST_TEST(DofMaskTest, GetFromArrayWithReturn) {

0 commit comments

Comments
 (0)