diff --git a/plugin/xprof/protobuf/op_metrics.proto b/plugin/xprof/protobuf/op_metrics.proto index fac258722..e35a03199 100644 --- a/plugin/xprof/protobuf/op_metrics.proto +++ b/plugin/xprof/protobuf/op_metrics.proto @@ -107,8 +107,9 @@ message OpMetrics { string name = 6; // Long name of this op (e.g., HLO expression). string long_name = 20; - // Category of this op. (e.g. Hlo op category, Framework op type) - // Could be parsed from provenance if it is a framework op. + // Category of this op. (e.g. Hlo op category, Framework op type, input + // pipeline stage category) Could be parsed from provenance if it is a + // framework op. string category = 11; // Provenance of this op if it is an HLO Op. (e.g. TF Op name, JAX Op name) // TODO(b/310434797) Extends this for JAX as now only TF Op is populated. diff --git a/xprof/convert/xplane_to_op_metrics_db.cc b/xprof/convert/xplane_to_op_metrics_db.cc index 67ce66281..bc5186b85 100644 --- a/xprof/convert/xplane_to_op_metrics_db.cc +++ b/xprof/convert/xplane_to_op_metrics_db.cc @@ -121,9 +121,14 @@ void ProcessOneTfActivity(const TfActivity& activity, } tsl::profiler::Timespan tf_op_span = tsl::profiler::PicoSpan( info->start_timestamp_ps, activity.timestamp_ps); + // Note the tf_op.id will be used as the hlo_module_id in EnterOp when + // constructing the op metrics db. + // - not set for legacy TfOp: behavior unchanged with hlo_module_id=0 + // - for input pipeline ops, this is the stage id. 
tf_metrics_data->tf_metrics_db_builder.EnterOp( activity.tf_op.name, activity.tf_op.type, activity.is_eager, - tf_op_span.duration_ps(), info->children_duration_ps); + tf_op_span.duration_ps(), info->children_duration_ps, + activity.tf_op.id); TfOpInfo* parent_info = tf_op_stack->Top(); if (parent_info != nullptr) { parent_info->children_duration_ps += tf_op_span.duration_ps(); @@ -161,56 +166,44 @@ void CollectTfActivities( uint32 tf_op_id = 0; if (tsl::profiler::IsDerivedThreadId(line.Id())) return; tf_activities->reserve(line.NumEvents() * 2); - line.ForEachEvent([&tf_ops, &tf_op_id, - &tf_activities](const XEventVisitor& event) { - const tsl::profiler::TfOp* tf_op = tsl::gtl::FindOrNull(tf_ops, event.Id()); - if (tf_op != nullptr) { - ++tf_op_id; - bool is_eager = false; - if (std::optional stat = - event.GetStat(StatType::kIsEager)) { - is_eager = stat->IntValue(); - } - tsl::profiler::Timespan span = event.GetTimespan(); - tf_activities->push_back( - {span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager}); - tf_activities->push_back( - {span.end_ps(), tf_op_id, kTfOpEnd, *tf_op, is_eager}); - } - if (auto tf_op_stat = event.GetStat(StatType::kTfOp); - tf_op_stat.has_value()) { - ++tf_op_id; - tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(tf_op_stat->StrOrRefValue()); - tsl::profiler::Timespan span = event.GetTimespan(); - tf_activities->push_back( - {span.begin_ps(), tf_op_id, kTfOpBegin, tf_op, false}); - tf_activities->push_back( - {span.end_ps(), tf_op_id, kTfOpEnd, tf_op, false}); - } - }); + line.ForEachEvent( + [&tf_ops, &tf_op_id, &tf_activities](const XEventVisitor& event) { + auto id = event.Id(); + // Add id override for input pipeline ops. 
+ if (const auto& stat = event.GetStat(StatType::kInputPipelineStageId); + stat.has_value()) { + id = stat->IntValue(); + } + const tsl::profiler::TfOp* tf_op = tsl::gtl::FindOrNull(tf_ops, id); + if (tf_op != nullptr) { + ++tf_op_id; + bool is_eager = false; + if (std::optional stat = + event.GetStat(StatType::kIsEager)) { + is_eager = stat->IntValue(); + } + tsl::profiler::Timespan span = event.GetTimespan(); + tf_activities->push_back( + {span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager}); + tf_activities->push_back( + {span.end_ps(), tf_op_id, kTfOpEnd, *tf_op, is_eager}); + } + if (auto tf_op_stat = event.GetStat(StatType::kTfOp); + tf_op_stat.has_value()) { + ++tf_op_id; + tsl::profiler::TfOp tf_op = + tsl::profiler::ParseTfOpFullname(tf_op_stat->StrOrRefValue()); + tsl::profiler::Timespan span = event.GetTimespan(); + tf_activities->push_back( + {span.begin_ps(), tf_op_id, kTfOpBegin, tf_op, false}); + tf_activities->push_back( + {span.end_ps(), tf_op_id, kTfOpEnd, tf_op, false}); + } + }); } } // namespace -absl::flat_hash_map -CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace) { - absl::flat_hash_map tf_ops; - for (const auto& id_metadata : host_trace.event_metadata()) { - const XEventMetadata& metadata = id_metadata.second; - // On the host, we have added some user-specified TraceMe's in addition to - // the TraceMe's added to every TensorFlow op by the system. These - // user-inserted TraceMe's have "unknown" type. We don't count them in - // Tf-stats. 
- tsl::profiler::TfOp tf_op = - tsl::profiler::ParseTfOpFullname(metadata.name()); - if (tf_op.category != tsl::profiler::Category::kUnknown) { - tf_ops.try_emplace(metadata.id(), tf_op); - } - } - return tf_ops; -} - TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( const XLineVisitor& line, const absl::flat_hash_map& tf_ops) { @@ -229,11 +222,60 @@ void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst) { src.tf_metrics_db.Clear(); } +absl::flat_hash_map +CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace) { + absl::flat_hash_map tf_ops; + XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&host_trace); + plane.ForEachLine([&tf_ops](const XLineVisitor& line) { + line.ForEachEvent( + [&tf_ops](const XEventVisitor& event) { + // 1. Newly added input pipeline ops processing: identified by the + // stage id and category. + auto input_pipeline_stage_id = + event.GetStat(StatType::kInputPipelineStageId); + if (input_pipeline_stage_id.has_value()) { + auto input_pipeline_stage_category = + event.GetStat(StatType::kInputPipelineStageCategory); + // Note that main thread traceme events are also identified by the + // stage id and name, but they do not do the actual work, so we do + // not set the type for them. Only worker thread events have the + // type set to reflect the stage category (read, preprocessing, + // enqueue, unknown). + if (input_pipeline_stage_category.has_value()) { + tsl::profiler::TfOp tf_op = tsl::profiler::ParseTfOpFullname( + event.Name(), tsl::profiler::Category::kInputPipeline, + input_pipeline_stage_category->StrOrRefValue(), + input_pipeline_stage_id->IntValue()); + // Note: the input pipeline stage id is used as the unique key + // here instead of the event id, because an event id's uniqueness + // is bound to the event name string due to the nature of xplane + // event metadata creation, making it an insufficient identifier + // when building an input pipeline event stack. 
+ tf_ops.try_emplace(input_pipeline_stage_id->IntValue(), tf_op); + } + return; + } + + // 2. Fallback to legacy host ops processing. + // On the host, we have added some user-specified TraceMe's in + // addition to the TraceMe's added to every TensorFlow op by the + // system. These user-inserted TraceMe's have "unknown" type. We don't + // count them in Tf-stats. + tsl::profiler::TfOp tf_op = + tsl::profiler::ParseTfOpFullname(event.Name()); + if (tf_op.category != tsl::profiler::Category::kUnknown) { + tf_ops.try_emplace(event.Id(), tf_op); + } + }); + }); + return tf_ops; +} + OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace) { - absl::flat_hash_map tf_ops = - CollectTfOpsFromHostThreadsXPlane(host_trace); OpMetricsDb result; OpMetricsDbCombiner combiner(&result); + absl::flat_hash_map tf_ops = + CollectTfOpsFromHostThreadsXPlane(host_trace); XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&host_trace); plane.ForEachLine([&tf_ops, &combiner](const XLineVisitor& line) { ConsumeTfMetricsDbData( diff --git a/xprof/convert/xplane_to_op_metrics_db.h b/xprof/convert/xplane_to_op_metrics_db.h index 3d3c8e434..53c59bcbd 100644 --- a/xprof/convert/xplane_to_op_metrics_db.h +++ b/xprof/convert/xplane_to_op_metrics_db.h @@ -44,13 +44,13 @@ struct TfMetricsDbData { HostOpMetricsDbBuilder tf_metrics_db_builder{&tf_metrics_db}; }; -absl::flat_hash_map -CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace); - TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( const XLineVisitor& line, const absl::flat_hash_map& tf_ops); +absl::flat_hash_map +CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace); + void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst); OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace); diff --git a/xprof/convert/xplane_to_op_metrics_db_test.cc b/xprof/convert/xplane_to_op_metrics_db_test.cc index 28e9ab097..5cc99bf10 100644 --- 
a/xprof/convert/xplane_to_op_metrics_db_test.cc +++ b/xprof/convert/xplane_to_op_metrics_db_test.cc @@ -54,6 +54,7 @@ using ::tsl::profiler::XStatsBuilder; #if defined(PLATFORM_GOOGLE) // NOLINTNEXTLINE: clang-tidy missing-includes using ::testing::EqualsProto; +using ::testing::proto::IgnoringRepeatedFieldOrdering; #endif void AddTensorFlowTpuOpEvent(std::string&& name, std::string&& tf_op_fullname, @@ -99,6 +100,23 @@ void AddTensorFlowOpEvent(std::string&& tf_op_fullname, *plane->GetOrCreateStatMetadata(std::move(tf_op_fullname))); } +void AddInputPipelineTracemeEvent(std::string&& name, + int64_t start_timestamp_ns, + int64_t duration_ns, + absl::string_view stage_category, + int64_t stage_id, XPlaneBuilder* plane, + XLineBuilder* line) { + XEventBuilder event = line->AddEvent(*plane->GetOrCreateEventMetadata(name)); + event.SetTimestampNs(start_timestamp_ns); + event.SetDurationNs(duration_ns); + event.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kInputPipelineStageId)), + stage_id); + event.AddStatValue(*plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kInputPipelineStageCategory)), + move(stage_category)); +} + void AddXlaCpuOpEvent(std::string&& hlo_op_name, std::string&& tf_op, int64_t start_timestamp_ns, int64_t duration_ns, XPlaneBuilder* plane, XLineBuilder* line) { @@ -310,6 +328,67 @@ TEST(ConvertXPlaneToOpMetricsDb, HostXPlaneWithXlaOps) { #endif } +TEST(ConvertXPlaneToOpMetricsDb, HostXPlaneWithInputPipelineTracemeOps) { + XPlane xplane; + XPlaneBuilder plane(&xplane); + XLineBuilder line = plane.GetOrCreateLine(/*line_id=*/10); + AddInputPipelineTracemeEvent("ShuffleMapDataset", 100000, 10000, + "preprocessing", 1, &plane, &line); + AddInputPipelineTracemeEvent("MapMapDataset", 100000, 8000, "preprocessing", + 2, &plane, &line); + AddInputPipelineTracemeEvent("ShuffleMapDataset", 120000, 10000, + "preprocessing", 3, &plane, &line); + AddInputPipelineTracemeEvent("MapMapDataset", 120000, 8000, "preprocessing", + 
4, &plane, &line); + + OpMetricsDb op_metrics = ConvertHostThreadsXPlaneToOpMetricsDb(xplane); +#if defined(PLATFORM_GOOGLE) + EXPECT_THAT(op_metrics, IgnoringRepeatedFieldOrdering( + EqualsProto(R"pb(metrics_db { + self_time_ps: 2000000 + occurrences: 1 + name: "ShuffleMapDataset" + category: "preprocessing" + hlo_module_id: 1 + time_ps: 10000000 + } + metrics_db { + self_time_ps: 8000000 + occurrences: 1 + name: "MapMapDataset" + category: "preprocessing" + hlo_module_id: 2 + time_ps: 8000000 + } + metrics_db { + self_time_ps: 2000000 + occurrences: 1 + name: "ShuffleMapDataset" + category: "preprocessing" + hlo_module_id: 3 + time_ps: 10000000 + } + metrics_db { + self_time_ps: 8000000 + occurrences: 1 + name: "MapMapDataset" + category: "preprocessing" + hlo_module_id: 4 + time_ps: 8000000 + } + metrics_db { + self_time_ps: 10000000 + name: "IDLE" + time_ps: 10000000 + category: "IDLE" + } + total_time_ps: 30000000 + total_op_time_ps: 20000000 + precision_stats {} + )pb"))); +#endif +} + TEST(ConvertXPlaneToOpMetricsDb, DeviceOpMetricsDbWithNullPerformanceInfo) { std::string hlo_string = R"( HloModule TestModule diff --git a/xprof/convert/xplane_to_op_stats.cc b/xprof/convert/xplane_to_op_stats.cc index a025568ad..a40fe3399 100644 --- a/xprof/convert/xplane_to_op_stats.cc +++ b/xprof/convert/xplane_to_op_stats.cc @@ -544,6 +544,7 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, const XPlane* host_plane = tsl::profiler::FindPlaneWithName( space, tsl::profiler::kHostThreadsPlaneName); if (host_plane) { + // TODO(yinzz): support legacy analysis path too? 
if (options.generate_op_metrics_db) { *op_stats.mutable_host_op_metrics_db() = ConvertHostThreadsXPlaneToOpMetricsDb(*host_plane); diff --git a/xprof/utils/op_utils.cc b/xprof/utils/op_utils.cc index 775a7369c..1265eaeec 100644 --- a/xprof/utils/op_utils.cc +++ b/xprof/utils/op_utils.cc @@ -99,10 +99,12 @@ void EnterOpMetadataFromHloModuleMap(OpMetrics* op_metrics, void HostOpMetricsDbBuilder::EnterOp(absl::string_view name, absl::string_view category, bool is_eager, - uint64 time_ps, uint64 children_time_ps) { + uint64 time_ps, uint64 children_time_ps, + int64_t id) { uint64 self_time_ps = time_ps - children_time_ps; DCHECK_GE(time_ps, self_time_ps); - OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(/*hlo_module_id=*/0, name); + OpMetrics* op_metrics = + LookupOrInsertNewOpMetrics(/*hlo_module_id=*/id, name); if (op_metrics->category().empty()) op_metrics->set_category(category.data(), category.size()); op_metrics->set_num_cores(1); diff --git a/xprof/utils/op_utils.h b/xprof/utils/op_utils.h index d17b7ac53..426ba8c18 100644 --- a/xprof/utils/op_utils.h +++ b/xprof/utils/op_utils.h @@ -63,8 +63,11 @@ class HostOpMetricsDbBuilder : public OpMetricsDbBuilder { // the execution time of its children. // children_time_ps = the execution time of the children of this OP in // picoseconds + // id = host op uniqueness identifier. For input pipeline ops, this is the + // stage id. By default is 0 if not needed. void EnterOp(absl::string_view name, absl::string_view category, - bool is_eager, uint64 time_ps, uint64 children_time_ps); + bool is_eager, uint64 time_ps, uint64 children_time_ps, + int64_t id = 0); // Updates total_host_infeed_enq_duration_ps_ and // total_host_infeed_enq_duration_ps_.