diff --git a/xprof/convert/xplane_to_op_stats.cc b/xprof/convert/xplane_to_op_stats.cc index e49abbda..db6cffbc 100644 --- a/xprof/convert/xplane_to_op_stats.cc +++ b/xprof/convert/xplane_to_op_stats.cc @@ -283,7 +283,8 @@ DutyCycleTracker ConstructDutyCycleTracker(XPlaneVisitor& visitor) { visitor.ForEachLine([&](const XLineVisitor& line) { if (line.Name() == tsl::profiler::kXlaOpLineName) { line.ForEachEvent([&](const XEventVisitor& event) { - auto hlo_category_stat = event.GetStat(StatType::kHloCategory); + auto hlo_category_stat = + event.Metadata().GetStat(StatType::kHloCategory); duty_cycle_tracker.AddInterval( event.GetTimespan(), !(hlo_category_stat && diff --git a/xprof/convert/xplane_to_op_stats_test.cc b/xprof/convert/xplane_to_op_stats_test.cc index 4f10c439..8b6ac90c 100644 --- a/xprof/convert/xplane_to_op_stats_test.cc +++ b/xprof/convert/xplane_to_op_stats_test.cc @@ -658,17 +658,22 @@ TEST(ConvertXPlaneToOpStats, ConstructDutyCycleTrackerFromXlaOps) { XPlaneBuilder device_plane_builder(device_plane); XLineBuilder op_line = device_plane_builder.GetOrCreateLine(0); op_line.SetName(kXlaOpLineName); + CreateXEventMetadata(&device_plane_builder, "op.1", + {{StatType::kHloCategory, tsl::profiler::kHloInfeed}}); CreateXEvent(&device_plane_builder, &op_line, "op.1", /*offset_ps=*/10, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloInfeed}}); + /*duration_ps=*/10); + CreateXEventMetadata(&device_plane_builder, "op.2", + {{StatType::kHloCategory, tsl::profiler::kHloCall}}); CreateXEvent(&device_plane_builder, &op_line, "op.2", /*offset_ps=*/20, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloCall}}); + /*duration_ps=*/10); + CreateXEventMetadata(&device_plane_builder, "op.3", + {{StatType::kHloCategory, tsl::profiler::kHloCall}}); CreateXEvent(&device_plane_builder, &op_line, "op.3", /*offset_ps=*/30, /*duration_ps=*/10); + CreateXEventMetadata(&device_plane_builder, "op.4", + {{StatType::kHloCategory, tsl::profiler::kHloOutfeed}}); CreateXEvent(&device_plane_builder, &op_line, "op.4", /*offset_ps=*/40, - /*duration_ps=*/10, - {{StatType::kHloCategory, tsl::profiler::kHloOutfeed}}); + /*duration_ps=*/10); XLineBuilder xla_module_line = device_plane_builder.GetOrCreateLine(1); xla_module_line.SetName(kXlaModuleLineName); CreateXEvent(&device_plane_builder, &xla_module_line, "module.1",