diff --git a/common/autodiff_overloads.h b/common/autodiff_overloads.h index e5cc1f16c3b7..f747f990f68d 100644 --- a/common/autodiff_overloads.h +++ b/common/autodiff_overloads.h @@ -121,13 +121,19 @@ pow(const Eigen::AutoDiffScalar& base, typename internal::remove_all::type::PlainObject, typename internal::remove_all::type::PlainObject>, "The derivative types must match."); - - internal::make_coherent(base.derivatives(), exponent.derivatives()); + using DerType = typename internal::remove_all::type::PlainObject; const auto& x = base.value(); - const auto& xgrad = base.derivatives(); + const DerType& xgrad = base.derivatives(); const auto& y = exponent.value(); - const auto& ygrad = exponent.derivatives(); + const DerType& ygrad = exponent.derivatives(); + + // Make the derivative sizes coherenent. + if (xgrad.size() == 0 && ygrad.size() > 0) { + return pow(MakeAutoDiffScalar(x, DerType::Zero(ygrad.size())), exponent); + } else if (xgrad.size() > 0 && ygrad.size() == 0) { + return pow(base, MakeAutoDiffScalar(y, DerType::Zero(xgrad.size()))); + } using std::log; using std::pow;