diff --git a/xprof/convert/megascale_perfetto/BUILD b/xprof/convert/megascale_perfetto/BUILD index ce121dd09..7083bc090 100644 --- a/xprof/convert/megascale_perfetto/BUILD +++ b/xprof/convert/megascale_perfetto/BUILD @@ -67,5 +67,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_googlesource_code_re2//:re2", + "@xla//xla/tsl/lib/gtl:map_util", ], ) diff --git a/xprof/convert/megascale_perfetto/trace_processor.cc b/xprof/convert/megascale_perfetto/trace_processor.cc index a979b4a8c..ec7b61350 100644 --- a/xprof/convert/megascale_perfetto/trace_processor.cc +++ b/xprof/convert/megascale_perfetto/trace_processor.cc @@ -19,6 +19,7 @@ #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "re2/re2.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xprof/convert/megascale_perfetto/xprof_trace.h" namespace xprof::megascale { @@ -110,16 +111,25 @@ absl::string_view ExtractRendezvousFromGraphKey(const Event& event, return ""; } -int64_t ExtractHloId(absl::string_view name, absl::string_view prefix) { - if (absl::ConsumePrefix(&name, prefix)) { - int64_t id; - size_t end = 0; - while (end < name.size() && std::isdigit(name[end])) { - end++; - } - if (end > 0 && absl::SimpleAtoi(name.substr(0, end), &id)) { - return id; - } +int ExtractMultiRecvCountFromLongName(const Event& event, + const XprofTrace& trace) { + int count = 1; // default + absl::string_view long_name; + if (!FindArgString(event, trace, "long_name", &long_name)) { + return count; + } + static constexpr LazyRE2 kMultiRecvCountRe = { + R"re(_xla_multi_recv_count="(\d+)")re"}; + // If successful, it will overwrite `count`. + RE2::PartialMatch(long_name, *kMultiRecvCountRe, &count); + return count; +} + +int64_t ExtractHloId(absl::string_view name) { + static constexpr LazyRE2 kHloId = {R"re([a-zA-Z_-]+\.(\d+))re"}; + int64_t id; + if (RE2::FullMatch(name, *kHloId, &id)) { + return id; } return -1; } @@ -177,7 +187,7 @@ class FlowQueueMap { void TraceProcessor::Process() { SortEvents(); AssignRunIds(); - MarkLastH2DEvents(); + MarkLastDmaEvents(); ResolveFlows(); AddNetworkCounters(); ModifyTrackNames(); @@ -247,7 +257,7 @@ void TraceProcessor::AssignRunIds() { assign_ids(trace_.megascale_fragments); } -void TraceProcessor::MarkLastH2DEvents() { +void TraceProcessor::MarkLastDmaEvents() { static constexpr LazyRE2 kExecutionEventRe = {R"(device_\d+_gid_.*)"}; for (auto& [tpu_id, tracks] : trace_.megascale_fragments) { for (auto& track : tracks) { @@ -263,6 +273,7 @@ void TraceProcessor::MarkLastH2DEvents() { int64_t execution_end_ps = exec_event.timestamp_ps + exec_event.duration_ps; Event* last_h2d_event = nullptr; + Event* last_d2h_event = nullptr; for (size_t j = i + 1; j < track.events.size(); ++j) { if (track.events[j].timestamp_ps >= execution_end_ps) { @@ -271,11 +282,18 @@ void TraceProcessor::MarkLastH2DEvents() { if (track.events[j].name == "HostToDevice END") { last_h2d_event = &track.events[j]; } + if (track.events[j].name == "DeviceToHost END") { + last_d2h_event = &track.events[j]; + } } if (last_h2d_event != nullptr) { last_h2d_event->args.push_back( - {trace_.string_table.Intern("is_last_h2d"), int64_t{1}}); + {trace_.string_table.Intern("is_last_instance"), int64_t{1}}); + } + if (last_d2h_event != nullptr) { + last_d2h_event->args.push_back( + {trace_.string_table.Intern("is_last_instance"), int64_t{1}}); } } } @@ -286,14 +304,17 @@ void TraceProcessor::ResolveFlows() { // map key: tpu_id, run_id, hlo_id absl::flat_hash_map send_hlo_to_rendezvous; absl::flat_hash_map recv_hlo_to_rendezvous; + absl::flat_hash_map recv_hlo_to_multi_count; // Separate queues for different consumer types to ensure the right ID goes to // the right event type, even if counts mismatch slightly. FlowQueueMap q_send_to_d2h; FlowQueueMap q_d2h_to_send_done; - FlowQueueMap q_send_done_to_h2d; FlowQueueMap q_h2d_to_recv_done; FlowQueueMap q_recv_to_recv_done; + FlowQueueMap q_recv_to_h2d; + FlowQueueMap q_send_to_send_done; + FlowQueueMap q_send_to_recv_done; auto make_key = [](int64_t tpu, int64_t run, absl::string_view rendezvous) { return absl::StrCat(tpu, "_", run, "_", rendezvous); @@ -302,9 +323,26 @@ void TraceProcessor::ResolveFlows() { return absl::StrCat(tpu, "_", run, "_", hlo_id); }; - // Helper to visit TPU events of specific types + auto push_flow = [&](FlowQueueMap& queue, const std::string& key, + Event& event) { + int64_t fid = next_flow_id_++; + event.flows.push_back({fid, FlowDirection::kSource}); + event.args.push_back({trace_.string_table.Intern("flow_out"), fid}); + queue.Push(key, fid); + }; + + auto pop_flow = [&](FlowQueueMap& queue, const std::string& key, + Event& event) { + int64_t fid = queue.Pop(key); + if (fid != -1) { + event.args.push_back({trace_.string_table.Intern("flow_in"), fid}); + event.flows.push_back({fid, FlowDirection::kSink}); + } + return fid; + }; + + // Helper to visit XLA ops in TPU fragments. auto visit_tpu_ops = [&](auto visitor) { - // Visit TPU Fragments for (auto& [tpu_id, tracks] : trace_.tpu_fragments) { for (auto& track : tracks) { if (!absl::StrContains(track.name, "XLA Ops")) { @@ -320,9 +358,8 @@ void TraceProcessor::ResolveFlows() { } }; - // Helper to visit Megascale events of specific types + // Helper to visit all events in Megascale fragments. auto visit_megascale = [&](auto visitor) { - // Visit Megascale Fragments for (auto& [tpu_id, tracks] : trace_.megascale_fragments) { for (auto& track : tracks) { for (auto& event : track.events) { @@ -336,197 +373,170 @@ void TraceProcessor::ResolveFlows() { }; // --------------------------------------------------------------------------- - // Pass 1: 'send' and 'recv' events + // Pass 1: Producers // --------------------------------------------------------------------------- + // TPU Events: 'send' and 'recv'. visit_tpu_ops([&](int64_t tpu_id, Track& track, Event& event) { + bool is_send; if (absl::StartsWith(event.name, "send.")) { - int64_t hlo_id = ExtractHloId(event.name, "send."); - if (hlo_id == -1) { - return; - } - absl::string_view rendezvous = - ExtractRendezvousFromLongName(event, trace_); - if (rendezvous.empty()) { - return; - } - send_hlo_to_rendezvous[make_hlo_key(tpu_id, event.run_id, hlo_id)] = - rendezvous; - - int64_t flow_id = next_flow_id_++; - - // Add flow to the send event itself (Source) - event.flows.push_back({flow_id, FlowDirection::kSource}); - event.args.push_back({trace_.string_table.Intern("flow_out"), flow_id}); - q_send_to_d2h.Push(make_key(tpu_id, event.run_id, rendezvous), flow_id); + is_send = true; } else if (absl::StartsWith(event.name, "recv.")) { - int64_t hlo_id = ExtractHloId(event.name, "recv."); - if (hlo_id == -1) { - return; - } - absl::string_view rendezvous = - ExtractRendezvousFromLongName(event, trace_); - if (rendezvous.empty()) { - return; - } - recv_hlo_to_rendezvous[make_hlo_key(tpu_id, event.run_id, hlo_id)] = - rendezvous; - - int64_t flow_id = next_flow_id_++; - event.flows.push_back({flow_id, FlowDirection::kSource}); - event.args.push_back({trace_.string_table.Intern("flow_out"), flow_id}); - q_recv_to_recv_done.Push(make_key(tpu_id, event.run_id, rendezvous), - flow_id); + is_send = false; + } else { + return; // Skip other events. } - }); - // --------------------------------------------------------------------------- - // Pass 2: Consumers (D2H, send-done, H2D, recv-done) - // --------------------------------------------------------------------------- - // D2H consumes from send, produces for send-done - visit_megascale([&](int64_t tpu_id, Event& event) { - if (event.name == "DeviceToHost START") { - int64_t action_idx = -1; - if (!FindArgInt(event, trace_, "action_index", &action_idx) || - action_idx != 0) { - return; - } - absl::string_view rendezvous = - ExtractRendezvousFromGraphKey(event, trace_); - if (rendezvous.empty()) { - return; - } - const std::string key = make_key(tpu_id, event.run_id, rendezvous); - int64_t fid_in = q_send_to_d2h.Pop(key); - if (fid_in == -1) { - return; - } - event.args.push_back({trace_.string_table.Intern("flow_in"), fid_in}); - event.flows.push_back({fid_in, FlowDirection::kSink}); - - int64_t fid_out = next_flow_id_++; - event.args.push_back({trace_.string_table.Intern("flow_out"), fid_out}); - event.flows.push_back({fid_out, FlowDirection::kSource}); - q_d2h_to_send_done.Push(key, fid_out); - } - }); - // send-done consumes from D2H, produces for H2D - visit_tpu_ops([&](int64_t tpu_id, Track& track, Event& event) { - if (absl::StartsWith(event.name, "send-done.")) { - int64_t hlo_id = ExtractHloId(event.name, "send-done."); - if (hlo_id == -1) { - return; - } - auto it = send_hlo_to_rendezvous.find( - make_hlo_key(tpu_id, event.run_id, hlo_id)); - if (it == send_hlo_to_rendezvous.end()) { - return; - } - absl::string_view rendezvous = it->second; - const std::string key = make_key(tpu_id, event.run_id, rendezvous); - int64_t fid_in = q_d2h_to_send_done.Pop(key); - if (fid_in == -1) { - return; - } - event.args.push_back({trace_.string_table.Intern("flow_in"), fid_in}); - event.flows.push_back({fid_in, FlowDirection::kSink}); - - int64_t fid_out = next_flow_id_++; - event.args.push_back({trace_.string_table.Intern("flow_out"), fid_out}); - event.flows.push_back({fid_out, FlowDirection::kSource}); - q_send_done_to_h2d.Push(key, fid_out); + int64_t hlo_id = ExtractHloId(event.name); + if (hlo_id == -1) { + return; } - }); - // H2D consumes from send-done, produces for recv-done - visit_megascale([&](int64_t tpu_id, Event& event) { - if (event.name != "HostToDevice END") { + absl::string_view rendezvous = ExtractRendezvousFromLongName(event, trace_); + if (rendezvous.empty()) { return; } + std::string key = make_key(tpu_id, event.run_id, rendezvous); + std::string hlo_key = make_hlo_key(tpu_id, event.run_id, hlo_id); + + if (is_send) { + send_hlo_to_rendezvous[hlo_key] = rendezvous; + push_flow(q_send_to_send_done, key, event); + push_flow(q_send_to_recv_done, key, event); + push_flow(q_send_to_d2h, key, event); + } else { + recv_hlo_to_rendezvous[hlo_key] = rendezvous; + recv_hlo_to_multi_count[hlo_key] = + ExtractMultiRecvCountFromLongName(event, trace_); + push_flow(q_recv_to_recv_done, key, event); + push_flow(q_recv_to_h2d, key, event); + } + }); - int64_t is_last_h2d = 0; - FindArgInt(event, trace_, "is_last_h2d", &is_last_h2d); - if (is_last_h2d != 1) { - return; + // Megascale DMA events: D2H and H2D. + // Note: We also process the consumers here since we have all necessary info. + visit_megascale([&](int64_t tpu_id, Event& event) { + bool is_d2h; + bool is_start; + if (event.name == "DeviceToHost START") { + is_d2h = true; + is_start = true; + } else if (event.name == "DeviceToHost END") { + is_d2h = true; + is_start = false; + } else if (event.name == "HostToDevice START") { + is_d2h = false; + is_start = true; + } else if (event.name == "HostToDevice END") { + is_d2h = false; + is_start = false; + } else { + return; // Skip this event. } absl::string_view rendezvous = ExtractRendezvousFromGraphKey(event, trace_); if (rendezvous.empty()) { return; } + std::string key = make_key(tpu_id, event.run_id, rendezvous); - // The outgoing flow event does not work correctly if the end time of this - // event is exactly the same as the end event of the megascale graph event - // it is nested order. To fix that, reduce the duration of this event by a - // nanosecond. - event.duration_ps = std::max(event.duration_ps - 1000, int64_t{0}); - - const std::string key = make_key(tpu_id, event.run_id, rendezvous); - int64_t fid_in = q_send_done_to_h2d.Pop(key); - if (fid_in == -1) { - LOG_EVERY_N(WARNING, 1000) << "Failed to find flow for H2D: " << key; - return; - } - event.args.push_back({trace_.string_table.Intern("flow_in"), fid_in}); - event.flows.push_back({fid_in, FlowDirection::kSink}); - - int64_t fid_out = next_flow_id_++; - event.args.push_back({trace_.string_table.Intern("flow_out"), fid_out}); - event.flows.push_back({fid_out, FlowDirection::kSource}); - q_h2d_to_recv_done.Push(key, fid_out); - }); - // recv-done consumes from H2D - visit_tpu_ops([&](int64_t tpu_id, Track& track, Event& event) { - if (absl::StartsWith(event.name, "recv-done.")) { - int64_t hlo_id = ExtractHloId(event.name, "recv-done."); - if (hlo_id == -1) { + if (is_start) { + // We're only interested in the first START event. + if (int64_t action_idx = -1; + !FindArgInt(event, trace_, "action_index", &action_idx) || + action_idx != 0) { return; } - auto it = recv_hlo_to_rendezvous.find( - make_hlo_key(tpu_id, event.run_id, hlo_id)); - if (it == recv_hlo_to_rendezvous.end()) { + if (is_d2h) { + pop_flow(q_send_to_d2h, key, event); + } else { + pop_flow(q_recv_to_h2d, key, event); + } + } else { // END event + // We're only interested in the last END event. + if (int64_t is_last = 0; + !FindArgInt(event, trace_, "is_last_instance", &is_last) || + is_last != 1) { return; } - absl::string_view rendezvous = it->second; - const std::string key = make_key(tpu_id, event.run_id, rendezvous); - - // Recv-done consumes from recv. - int64_t fid_in_recv = q_recv_to_recv_done.Pop(key); - if (fid_in_recv != -1) { - event.args.push_back( - {trace_.string_table.Intern("flow_in"), fid_in_recv}); - event.flows.push_back({fid_in_recv, FlowDirection::kSink}); + if (!is_d2h) { + // Reduce last H2D END duration slightly. This is needed because + // Perfetto will show the flow going out of the parent slice (megascale + // graph event) if this is the last event in the action graph and its + // end time matches the end time of the parent slice. + event.duration_ps = std::max(event.duration_ps - 1000, int64_t{0}); + } + if (is_d2h) { + push_flow(q_d2h_to_send_done, key, event); + } else { + push_flow(q_h2d_to_recv_done, key, event); } + } + }); - // Recv-done consumes from H2D. - int64_t fid_in = q_h2d_to_recv_done.Pop(key); + // --------------------------------------------------------------------------- + // Pass 2: Consumers (send-done, recv-done) + // --------------------------------------------------------------------------- + visit_tpu_ops([&](int64_t tpu_id, Track& track, Event& event) { + bool is_send_done; + if (absl::StartsWith(event.name, "send-done.")) { + is_send_done = true; + } else if (absl::StartsWith(event.name, "recv-done.")) { + is_send_done = false; + } else { + return; // Skip other events. + } - // Flow events from recv-done to recv-done END. - int64_t fid_internal = next_flow_id_++; - event.args.push_back( - {trace_.string_table.Intern("flow_out"), fid_internal}); - event.flows.push_back({fid_internal, FlowDirection::kSource}); + int64_t hlo_id = ExtractHloId(event.name); + if (hlo_id == -1) { + return; + } + std::string hlo_key = make_hlo_key(tpu_id, event.run_id, hlo_id); + absl::string_view rendezvous = + is_send_done + ? tsl::gtl::FindWithDefault(send_hlo_to_rendezvous, hlo_key, "") + : tsl::gtl::FindWithDefault(recv_hlo_to_rendezvous, hlo_key, ""); + if (rendezvous.empty()) return; + std::string key = make_key(tpu_id, event.run_id, rendezvous); + + if (is_send_done) { + pop_flow(q_send_to_send_done, key, event); + pop_flow(q_d2h_to_send_done, key, event); + } else { + pop_flow(q_send_to_recv_done, key, event); + pop_flow(q_recv_to_recv_done, key, event); + + // For H2D to recv-done, we may have multiple. + int multi_count = + tsl::gtl::FindWithDefault(recv_hlo_to_multi_count, hlo_key, 1); + + // Instead of attaching the flows to the recv-done event, let's create an + // instant "recv-done END" event that begins right after the recv-done + // finishes and attach the flows to it. We do this because the H2D events + // may end after the recv-done has started and Perfetto does not handle + // that case well. (end time of producer slice is later than start time of + // consumer slice) - // Create an instant event to mark the end of the recv-done. Event instant_event; instant_event.name = absl::StrCat(event.name, " END"); - // For the flow event from "recv-done" to "recv-done END" to work, the - // latter event must begin after the former finishes. We set the start - // time of the END event to be 1 ns after the end of the recv-done. A - // nanosecond is the smallest supported time unit in Perfetto. instant_event.timestamp_ps = event.timestamp_ps + event.duration_ps + 1000; instant_event.duration_ps = 0; instant_event.run_id = event.run_id; instant_event.args.push_back( - {trace_.string_table.Intern("flow_in"), fid_in}); + {trace_.string_table.Intern("run_id"), event.run_id}); + // Add a flow from the recv-done to the recv-done END. + int64_t fid_internal = next_flow_id_++; + event.args.push_back( + {trace_.string_table.Intern("flow_out"), fid_internal}); + event.flows.push_back({fid_internal, FlowDirection::kSource}); instant_event.args.push_back( {trace_.string_table.Intern("flow_in"), fid_internal}); - instant_event.args.push_back( - {trace_.string_table.Intern("run_id"), event.run_id}); - if (fid_in != -1) { - instant_event.flows.push_back({fid_in, FlowDirection::kSink}); - } instant_event.flows.push_back({fid_internal, FlowDirection::kSink}); - track.events.push_back(instant_event); + + for (int i = 0; i < multi_count; ++i) { + pop_flow(q_h2d_to_recv_done, key, instant_event); + } + + track.events.push_back(std::move(instant_event)); } }); } diff --git a/xprof/convert/megascale_perfetto/trace_processor.h b/xprof/convert/megascale_perfetto/trace_processor.h index 069835717..e9af4e19f 100644 --- a/xprof/convert/megascale_perfetto/trace_processor.h +++ b/xprof/convert/megascale_perfetto/trace_processor.h @@ -29,9 +29,9 @@ class TraceProcessor { // don't assign run IDs to program runs that were not fully captured in the // profile. void AssignRunIds(); - // Marks the last H2D events for each execution event. This is needed for - // adding H2D -> recv-done flows. - void MarkLastH2DEvents(); + // Marks the last D2H and H2D events for each action graph execution. This is + // needed for adding flows. + void MarkLastDmaEvents(); // Resolves flows between TPU events and Megascale events. void ResolveFlows(); // Adds a counter track for network metrics. diff --git a/xprof/convert/megascale_perfetto/xspace_loader.cc b/xprof/convert/megascale_perfetto/xspace_loader.cc index eb551c9b2..0d1971e93 100644 --- a/xprof/convert/megascale_perfetto/xspace_loader.cc +++ b/xprof/convert/megascale_perfetto/xspace_loader.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -126,6 +127,42 @@ void ExtractEventArgs(const XEventVisitor& event, StringTable& string_table, event.ForEachStat(for_each_stat); } +struct GraphKeyInfo { + std::string short_name; + int64_t device_id = 0; + int64_t iteration = 0; + + bool operator<(const GraphKeyInfo& other) const { + if (short_name != other.short_name) { + return short_name < other.short_name; + } + if (device_id != other.device_id) { + return device_id < other.device_id; + } + return iteration < other.iteration; + } +}; + +GraphKeyInfo ParseGraphKey(absl::string_view graph_key) { + GraphKeyInfo info; + static constexpr LazyRE2 kGraphKeyRe = { + R"(device_(\d+)_gid_([^\$]+)\$.*\^i(\d+).*)"}; + RE2::FullMatch(graph_key, *kGraphKeyRe, &info.device_id, &info.short_name, + &info.iteration); + return info; +} + +absl::string_view GetGraphKey(const Event& event, const XprofTrace& trace) { + for (const auto& arg : event.args) { + if (trace.string_table.Get(arg.key) == "graph_key") { + if (std::holds_alternative(arg.value)) { + return trace.string_table.Get(std::get(arg.value)); + } + } + } + return ""; +} + } // namespace XprofTrace XSpaceLoader::Load(const tsl::profiler::XSpace& space) { @@ -217,13 +254,12 @@ XprofTrace XSpaceLoader::Load(const tsl::profiler::XSpace& space) { return; } - std::string track_name(ExtractShortGraphKey(graph_key)); - // Find or create track in the BUFFER - Track*& track = track_ptr_map[track_name]; + Track*& track = track_ptr_map[graph_key]; if (track == nullptr) { + std::string track_name(ExtractShortGraphKey(graph_key)); std::list& fragments = raw_megascale_fragments[raw_device_id]; - fragments.push_back(Track{track_name, {}}); + fragments.push_back(Track{std::move(track_name), {}}); track = &fragments.back(); } @@ -279,6 +315,19 @@ XprofTrace XSpaceLoader::Load(const tsl::profiler::XSpace& space) { device_id_to_tpu_id[raw_device_ids[i]] = tpu_ids[i]; } + // --------------------------------------------------------------------------- + // 3.5. Sort Tracks based on Graph Key (short_name, device_id, iteration) + // --------------------------------------------------------------------------- + for (auto& [raw_id, tracks] : raw_megascale_fragments) { + tracks.sort([&](const Track& a, const Track& b) { + if (a.events.empty()) return true; + if (b.events.empty()) return false; + GraphKeyInfo info_a = ParseGraphKey(GetGraphKey(a.events[0], trace)); + GraphKeyInfo info_b = ParseGraphKey(GetGraphKey(b.events[0], trace)); + return info_a < info_b; + }); + } + // --------------------------------------------------------------------------- // 4. Finalize (Move Buffer to Trace) // ---------------------------------------------------------------------------