diff --git a/programming_examples/basic/vector_scalar_mul/vector_scalar_mul_jit.py b/programming_examples/basic/vector_scalar_mul/vector_scalar_mul_jit.py new file mode 100644 index 00000000000..05551ee3f2b --- /dev/null +++ b/programming_examples/basic/vector_scalar_mul/vector_scalar_mul_jit.py @@ -0,0 +1,219 @@ +# vector_scalar_mul/vector_scalar_mul_jit.py -*- Python -*- +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# (c) Copyright 2024-2025 Advanced Micro Devices, Inc. or its affiliates + +import argparse +import sys +import numpy as np +import aie.iron as iron +import os + +from aie.iron import ExternalFunction, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer +from aie.iron.device import NPU1Col1, NPU2Col1 +from aie.iron.controlflow import range_ +from aie.iron.dtype import str_to_dtype +import argparse +import sys +import numpy as np +import aie.iron as iron + +from aie.iron import ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer +from aie.iron.device import NPU1Col1, NPU2Col1 +from aie.iron.controlflow import range_ +from aie.iron import trace + + +@iron.jit(is_placed=False) +def vector_scalar_mul(input0, input1, output): + if input0.shape != output.shape: + raise ValueError( + f"Input and output shapes are not the same ({input0.shape} != {output.shape})." 
+ ) + if len(np.shape(input0)) != 1: + raise ValueError("Function only supports vectors.") + + num_elements = np.size(input0) + + # Add size validation like in reference code + # Assert that input1 (factor) is size 4 bytes (1 integer) + if np.size(input1) != 1: + raise ValueError("2nd input buffer must be size 1 (1 integer).") + + # Assert output size matches input size + if output.numel() != input0.numel(): + raise ValueError("Output buffer size must match input buffer size.") + + num_sub_vectors = 4 + tile_size = num_elements // num_sub_vectors + + if num_elements % num_sub_vectors != 0: + raise ValueError( + f"Number of elements ({num_elements}) must be a multiple of {num_sub_vectors}." + ) + + if input0.dtype != output.dtype: + raise ValueError( + f"Input and output data types are not the same ({input0.dtype} != {output.dtype})." + ) + dtype = input0.dtype + + # Define tensor types - factor should be scalar_ty (np.int32), not tile_ty + tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] + tile_ty = np.ndarray[(tile_size,), np.dtype[dtype]] + scalar_ty = np.ndarray[(1,), np.dtype[np.int32]] + + # Create a handle to an externally-defined kernel + # Construct path to kernel source file + current_dir = os.path.dirname(__file__) + kernel_path = os.path.join(current_dir, "../../../aie_kernels/aie2", "scale.cc") + # Get the bit width directly from the dtype + bit_width = np.dtype(input0.dtype).itemsize * 8 + + # Use the same kernel function name as reference code + scale = ExternalFunction( + "vector_scalar_mul_vector", + source_file=kernel_path, + arg_types=[ + tile_ty, # input tensor + tile_ty, # output tensor + scalar_ty, # scalar factor + np.int32, # N + ], + compile_flags=[f"-DBIT_WIDTH={bit_width}"], + include_dirs=[os.path.join(current_dir, "../../../aie_kernels/aie2")], + ) + + # AIE-array data movement with object fifos + # Factor should be scalar_ty, not tensor_ty + of_in = ObjectFifo(tile_ty, name="in") + of_factor = ObjectFifo(scalar_ty, 
name="infactor") + of_out = ObjectFifo(tile_ty, name="out") + + # Define a task that will run on a compute tile + def core_body(of_in, of_factor, of_out, scale_fn): + # Acquire factor once outside the loop, like in reference code + elem_factor = of_factor.acquire(1) + + # Number of sub-vector "tile" iterations + for _ in range_(num_sub_vectors): + elem_in = of_in.acquire(1) + elem_out = of_out.acquire(1) + scale_fn(elem_in, elem_out, elem_factor, tile_size) + of_in.release(1) + of_out.release(1) + # Release factor once after the loop + of_factor.release(1) + + # Create a worker to run the task on a compute tile + # enable_trace = 1 if trace.get_trace_size() > 0 else 0 + worker = Worker( + core_body, + fn_args=[of_in.cons(), of_factor.cons(), of_out.prod(), scale], + trace=1 if trace.get_trace_size() > 0 else 0, + ) + + # Runtime operations to move data to/from the AIE-array + rt = Runtime() + + with rt.sequence(tensor_ty, scalar_ty, tensor_ty) as (A, F, C): + if trace.get_trace_size() > 0: + rt.enable_trace(trace.get_trace_size()) + rt.start(worker) + rt.fill(of_in.prod(), A) + rt.fill(of_factor.prod(), F) + rt.drain(of_out.cons(), C, wait=True) + + # Place program components (assign them resources on the device) and generate an MLIR module + return Program(iron.get_current_device(), rt).resolve_program(SequentialPlacer()) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-v", "--verbose", action="store_true", help="Enable verbose output" + ) + parser.add_argument( + "-n", + "--num-elements", + type=int, + default=1024, + help="Number of elements (default: 1024, must be multiple of 128 and >= 1024)", + ) + parser.add_argument( + "-t", + "--trace-size", + type=int, + default=1024, + help="Trace buffer size (0 = no tracing, default: 0)", + ) + parser.add_argument( + "-z", + "--data_type", + choices=["i16", "i32"], + default="i16", + help="Data type (default: i16)", + ) + args = parser.parse_args() + + # Buffer size validation like 
reference code + if args.num_elements % 128 != 0 or args.num_elements < 1024: + print( + "Number of elements must be a multiple of 128 (so len is multiple of 64) and greater than or equal to 1024 (so len >= 512)" + ) + raise ValueError + + # Construct input random tensors and an output zeroed tensor + # The tensors are in memory accessible to the NPU + datatype = str_to_dtype(args.data_type) + input0 = iron.randint(0, 100, (args.num_elements,), dtype=datatype, device="npu") + scalar = iron.randint(0, 100, (1,), dtype=np.int32, device="npu") + output = iron.zeros_like(input0) + + # Enable tracing if requested + if args.trace_size > 0: + trace.set_trace_size(args.trace_size) + trace.start_trace() + + # JIT-compile the kernel then launches the kernel with the given arguments + vector_scalar_mul(input0, scalar, output) + + # Stop tracing and save results if tracing was enabled + if args.trace_size > 0: + trace_filename = f"trace_output_{args.num_elements}_{args.data_type}.json" + trace.stop_trace(trace_filename) + print(f"Tracing completed and saved to {trace_filename}") + + # Check the correctness of the result - use scalar multiplication + expected = input0.numpy() * scalar.numpy()[0] + actual = output.numpy() + e = np.equal(expected, actual) + errors = np.size(e) - np.count_nonzero(e) + + # Optionally, print the results + if args.verbose: + print(f"{'input0':>4} * {'factor':>4} = {'output':>4}") + print("-" * 34) + count = input0.numel() + factor = scalar.numpy()[0] + for idx, (a, c) in enumerate(zip(input0[:count], output[:count])): + print(f"{idx:2}: {a:4} * {factor:4} = {c:4}") + + # If the result is correct, exit with a success code. 
+ # Otherwise, exit with a failure code + if not errors: + print("\nPASS!\n") + sys.exit(0) + else: + print("\nError count: ", errors) + print("\nFailed.\n") + sys.exit(-1) + + +if __name__ == "__main__": + main() diff --git a/python/iron/__init__.py b/python/iron/__init__.py index 228f4bcb0a8..4e373776e3e 100644 --- a/python/iron/__init__.py +++ b/python/iron/__init__.py @@ -24,5 +24,7 @@ arange, zeros_like, ) + + from . import trace except ImportError: pass # silently ignore if pyxrt or .jit can't be imported diff --git a/python/iron/jit.py b/python/iron/jit.py index 5bf38a938d3..ba9a76fd72d 100644 --- a/python/iron/jit.py +++ b/python/iron/jit.py @@ -21,6 +21,13 @@ from .compile import compile_mlir_module from .config import get_current_device from aie.dialects.aie import AIEDevice +from .tensor import zeros +from .trace import ( + _get_trace_active, + _get_trace_tensor, + _get_dummy_tensor, + set_mlir_module, +) # The `iron.jit` decorator below caches compiled kenrels inside the `IRON_CACHE_HOME` directory. @@ -142,6 +149,26 @@ def __call__(self, *args): ) kernel_args.append(tensor.buffer_object()) + if _get_trace_active(): + # We always put the trace tensor at the 5th argument to match backend tracing logic + # So we only enable tracing if we have at most 4 user arguments + trace_tensor = _get_trace_tensor() + if trace_tensor is None: + raise RuntimeError("Tracing active but no trace tensor found") + + if len(kernel_args) >= 5: + raise ValueError( + f"Tracing can only be done for kernels with 4 or fewer arguments. Got {len(kernel_args)} arguments." 
+ ) + + # Pad with dummy tensors if needed and add them to kernel_args + while len(kernel_args) < 4: + dummy_tensor = _get_dummy_tensor() + kernel_args.append(dummy_tensor.buffer_object()) + + # Add trace tensor as the 5th argument + kernel_args.append(trace_tensor.buffer_object()) + h = self.__kernel(opcode, self.__insts_buffer_bo, self.__n_insts, *kernel_args) r = h.wait() if r != xrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: @@ -282,12 +309,16 @@ def decorator(*args, **kwargs): xclbin_path=xclbin_path, work_dir=kernel_dir, ) + except Exception as e: # Clean up cache directory on any compilation failure to avoid any corrupted objects in the cache if os.path.exists(kernel_dir): shutil.rmtree(kernel_dir) raise e + # Set the MLIR module globally for tracing to use + set_mlir_module(str(mlir_module)) + kernel_name = "MLIR_AIE" try: kernel = NPUKernel(xclbin_path, inst_path, kernel_name=kernel_name) diff --git a/python/iron/runtime/runtime.py b/python/iron/runtime/runtime.py index 2bb44fbcadb..f19b9c511fc 100644 --- a/python/iron/runtime/runtime.py +++ b/python/iron/runtime/runtime.py @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # -# (c) Copyright 2024 Advanced Micro Devices, Inc. +# (c) Copyright 2024-2025 Advanced Micro Devices, Inc. from __future__ import annotations from collections import defaultdict @@ -35,6 +35,7 @@ InlineOpRuntimeTask, FinishTaskGroupTask, ) +from .. 
import trace class Runtime(Resolvable): @@ -73,6 +74,11 @@ def sequence(self, *input_types: type[np.ndarray]): """ try: self._rt_data = list(map(RuntimeData, input_types)) + + # Auto-enable tracing if tracing is active + if trace._get_trace_active() and self._trace_size is None: + self.enable_trace(trace_size=trace.get_trace_size()) + if len(self._rt_data) == 1: yield self._rt_data[0] else: diff --git a/python/iron/trace.py b/python/iron/trace.py new file mode 100644 index 00000000000..f04505e865b --- /dev/null +++ b/python/iron/trace.py @@ -0,0 +1,965 @@ +#!/usr/bin/env python3 +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# (c) Copyright 2025 AMD Inc. + +import json +import argparse +import sys +import re +import numpy as np +from .tensor import zeros +from .trace_events_enum import CoreEvent, MemEvent, ShimTileEvent, MemTileEvent + +# Add missing imports from parse_trace.py +from aie.extras.util import find_ops +from aie.ir import Context, Module, Location +import aie.dialects.aie as aiedialect +import aie.dialects.aiex as aiexdialect + +# Number of different trace types, currently 4 +# core: pkt type 0 +# mem: pkt type 1 +# shim: pkt type 2 +# memtile: pkt type 3 +NumTraceTypes = 4 +NUM_EVENTS = 8 # number of events we can view per trace + +# Global variables +_trace_active = False +_trace_tensor = None +DEBUG = False +_trace_size = 4096 # Default trace size (4k) +_trace_buffer = [] +_dummy_tensor = None # Reusable dummy tensor +_mlir_module = None # Global to store MLIR from JIT + + +def _prepare_trace_input(trace_input): + """ + Prepare trace input by detecting if it's a file path or data string. 
+ + Args: + trace_input: Either a file path string or data string + + Returns: + List of trace data strings + """ + if isinstance(trace_input, str): + # Assume it's a file path unless it contains newlines (then it's data) + if "\n" in trace_input: + # It's trace data as a string, split by newlines + return trace_input.split("\n") + else: + # Assume it's a file path + try: + with open(trace_input, "r") as f: + return f.read().split("\n") + except Exception as e: + raise RuntimeError(f"Could not open trace file {trace_input}: {e}") + else: + return trace_input + + +def parse_and_save_trace( + trace_input, mlir_input, colshift=None, debug=False, output_file=None +): + """ + Parse trace data and optionally save to file. + + Args: + trace_input: Either a string (file path) or list of strings (trace data) + mlir_input: Either a string (file path) or MLIR module string + colshift: Optional column shift adjustment + debug: Whether to enable debug mode + output_file: Optional file path or file object for output + + Returns: + None (output is written to file if output_file is provided) + """ + global DEBUG + DEBUG = debug + + # Handle trace input - can be file path, data, or numpy array + if hasattr(trace_input, "dtype"): # numpy array + trace_pkts = [f"{val:08x}" for val in trace_input] + else: + trace_pkts = _prepare_trace_input(trace_input) + + # Handle MLIR input - can be file path or data + if isinstance(mlir_input, str) and not mlir_input.strip().startswith("module"): + # Assume it's a file path + try: + with open(mlir_input, "r") as f: + mlir_module_str = f.read() + except Exception as e: + raise RuntimeError(f"Could not open MLIR file {mlir_input}: {e}") + else: + mlir_module_str = mlir_input + + # Parse MLIR trace events + try: + pid_events = parse_mlir_trace_events(mlir_module_str, colshift) + except Exception as e: + raise RuntimeError(f"Could not parse MLIR module: {e}") + + # Validate trace data + if not check_for_valid_trace( + trace_input if 
isinstance(trace_input, str) else "input_data", + trace_pkts, + output_file, + ): + raise RuntimeError("Invalid trace data") + + # Process trace data + trimmed_trace_pkts = trim_trace_pkts(trace_pkts) + if DEBUG and output_file: + lines_removed = len(trace_pkts) - len(trimmed_trace_pkts) + print(f"DEBUG: trimmed {lines_removed} lines", file=output_file) + + trace_pkts_sorted = trace_pkts_de_interleave(trimmed_trace_pkts) + byte_streams = convert_to_byte_stream(trace_pkts_sorted) + commands_0 = convert_to_commands(byte_streams, False) + + # Align column start index if no colshift provided + if colshift is None: + pid_events = align_column_start_index(pid_events, commands_0) + + # Generate trace events + trace_events = [] + setup_trace_metadata(trace_events, pid_events) + convert_commands_to_json(trace_events, commands_0, pid_events, output_file) + + # If output_file is provided, write JSON directly + if output_file: + # output_file is always a string filename, so open it and write + with open(output_file, "w") as f: + print( + json.dumps(trace_events).replace("'", '"').replace(", {", ",\n{"), + file=f, + ) + + +def lookup_event_name_by_type(trace_type, code): + """Look up event name by trace type and code.""" + event = "" + events_enum = None + if trace_type == 0: # Core traces + events_enum = CoreEvent + elif trace_type == 1: # Mem traces + events_enum = MemEvent + elif trace_type == 2: # Shim traces + events_enum = ShimTileEvent + elif trace_type == 3: # MemTile traces + events_enum = MemTileEvent + if events_enum is not None and code in set(x.value for x in events_enum): + event = events_enum(code).name + else: + event = "Unknown" + return event + + +def lookupEventNameInStr(event, pid, pid_events): + """ + TODO Expand to other pid for multiple cores? 
even/odd + For now, we assume a single trace event and key based on that + in the future, the pid will be used to match the right events + """ + return lookup_event_name_by_code(pid_events[0][int(event)]) + + +def lookup_event_name_by_code(code, pid_events): + """ + TODO Expand to other pid for multiple cores? even/odd + For now, we assume a single trace event and key based on that + in the future, the pid will be used to match the right events + """ + return lookup_event_name_by_type(0, pid_events[0][int(code)]) + + +def make_event_lists(commands): + """Create event lists from commands.""" + events = {} + ts = 0 + for i, command in enumerate(commands): + if command["type"] == "Start": + ts = command["timer_value"] + if command["type"] == "Event_Sync": + ts += 0x3FFFF # Typo in spec + if "Single" in command["type"]: + ts += command["cycles"] + if command["event"] in events: + events[command["event"]].append(ts) + else: + events[command["event"]] = [ts] + return events + + +def flatten_repeat_command(commands): + """Flatten repeat commands for processing.""" + prev = 0 + flat_commands = list() + for c in commands: + if c["type"] == "Repeat0" or c["type"] == "Repeat1": + for i in range(int(c["repeats"])): + flat_commands.append(prev) + else: + flat_commands.append(c) + prev = c + return flat_commands + + +def deactivate_events( + multiples, + active_events, + timer, + cycles, + pid, + trace_type, + loc, + pid_events, + trace_events, +): + """Deactivate events based on conditions.""" + for k in active_events.keys(): # an active event + if cycles > 0 or (cycles == 0 and not k in multiples): + if active_events[k] > 0: + trace_event = { + "name": lookup_event_name_by_type( + trace_type, pid_events[trace_type][loc][k] + ) + } + trace_event["ts"] = timer + trace_event["ph"] = "E" + trace_event["pid"] = pid + trace_event["tid"] = k + trace_event["args"] = {} + trace_events.append(trace_event) + active_events[k] = 0 + + +def activate_event(event, tt, loc, timer, pid, 
active_events, pid_events, trace_events): + """Activate an event.""" + try: + if active_events[event] == 0: + trace_event = { + "name": lookup_event_name_by_type(tt, pid_events[tt][loc][event]) + } + trace_event["ts"] = timer + trace_event["ph"] = "B" + trace_event["pid"] = pid + trace_event["tid"] = event + trace_event["args"] = {} + trace_events.append(trace_event) + active_events[event] = 1 + except KeyError: + pass + + +def process_name_metadata(trace_events, pid, trace_type, loc): + """Process name metadata for trace events.""" + trace_event = {"name": "process_name"} + trace_event["ph"] = "M" + trace_event["pid"] = pid + trace_event["args"] = {} + if trace_type == 0: + trace_event["args"]["name"] = "core_trace for tile" + str(loc) + elif trace_type == 1: + trace_event["args"]["name"] = "mem_trace for tile" + str(loc) + elif trace_type == 2: + trace_event["args"]["name"] = "shim_trace for tile" + str(loc) + elif trace_type == 3: + trace_event["args"]["name"] = "memtile_trace for tile" + str(loc) + trace_events.append(trace_event) + + +def thread_name_metadata(trace_events, trace_type, loc, pid, tid, pid_events): + """Process thread name metadata.""" + trace_event = {"name": "thread_name"} + trace_event["ph"] = "M" + trace_event["pid"] = pid + trace_event["tid"] = tid + trace_event["args"] = {} + trace_event["args"]["name"] = lookup_event_name_by_type( + trace_type, pid_events[trace_type][loc][tid] + ) + trace_events.append(trace_event) + + +def start_tracing(size=4096): + """ + Start tracing functionality with optional size parameter. + + Args: + size (int, optional): The maximum size of the trace buffer. Defaults to 4096 (4k). 
+ """ + global _trace_active, _trace_buffer, _trace_tensor, _trace_size, _dummy_tensor + + # Set trace size + if size > 0: + _trace_size = size + else: + raise RuntimeError("Trace size must be positive") + + if not _trace_active: + _trace_active = True + _trace_buffer = [] + + # Create the trace tensor when starting trace + _trace_tensor = zeros(_trace_size, dtype=np.uint32) + # Create a reusable dummy tensor + _dummy_tensor = zeros(1, dtype=np.uint32) + + +def stop_tracing(output_file="trace.json"): + """Stop tracing functionality and optionally save to file.""" + if not output_file.endswith(".json"): + raise RuntimeError("Only JSON output files are supported") + + global _trace_active, _trace_tensor, _mlir_module + if _trace_active: + _trace_active = False + # Read trace tensor data and save to file + if _trace_tensor is not None: + # Check if MLIR module is available + if _mlir_module is None: + raise RuntimeError( + "No MLIR module available for tracing. Did you call set_mlir_module()?" 
+ ) + + # Parse trace data using the library function + parse_and_save_trace( + trace_input=_trace_tensor.numpy(), + mlir_input=_mlir_module, + colshift=None, + debug=False, + output_file=output_file, + ) + else: + raise RuntimeError("No trace tensor found") + else: + raise RuntimeError("Tracing was not active") + + +def set_mlir_module(mlir_module): + """Set the MLIR module for tracing to use.""" + global _mlir_module + _mlir_module = mlir_module + + +# Getter functions +def get_trace_size(): + """Get the current trace buffer size limit.""" + return _trace_size + + +def _get_trace_active(): + """Get the current tracing status.""" + return _trace_active + + +def _get_trace_tensor(): + """Get the current trace tensor.""" + return _trace_tensor + + +def _get_dummy_tensor(): + """Get the reusable dummy tensor.""" + return _dummy_tensor + + +# Check for valid trace packets data +# 1) if only 1 trace packet +# 2) if first trace packet is all 0's +def check_for_valid_trace(filename, trace_pkts, of): + if DEBUG: + print("len(trace_pkts): ", str(len(trace_pkts)), file=of) + print("trace_pkts[0]:", trace_pkts[0], file=of) + if len(trace_pkts) < 2 or trace_pkts[0] == "00000000": + print( + "[ERROR] Empty trace file. 
Valid trace was not written to", + filename, + file=sys.stderr, + ) + print( + "See https://github.com/Xilinx/mlir-aie/tree/main/programming_guide/section-4/section-4b#Additional-Debug-Hints for additional trace debug tips.", + file=sys.stderr, + ) + return False + return True + + +def trim_trace_pkts(trace_pkts): + for i in range(len(trace_pkts)): + if trace_pkts[i] == "fefefefe" or trace_pkts[i] == "FEFEFEFE": + if i + 2 < len(trace_pkts): + if trace_pkts[i + 1] == "00000000" and trace_pkts[i + 2] == "00000000": + return trace_pkts[0 : i + 1] + return trace_pkts + + +def check_odd_word_parity(word): + val = 0 + for i in range(32): + val = val ^ ((word >> i) & 0x1) + return val == 1 + + +def parse_pkt_hdr_in_stream(word): + hdr = dict() + w = int(word) + hdr["valid"] = check_odd_word_parity(w) + # TODO can we assume non used fields must be 0 to rule out other data packets? + # what about bit[5:10]? + if (((w >> 5) & 0x7F) != 0) or (((w >> 19) & 0x1) != 0) or (((w >> 28) & 0x7) != 0): + hdr["valid"] = False + else: + # TODO Do we need to check for valid row/col for given device? 
+ hdr["col"] = (w >> 21) & 0x7F + hdr["row"] = (w >> 16) & 0x1F + hdr["type"] = (w >> 12) & 0x3 + hdr["id"] = w & 0x1F + return hdr + + +# Sorts list of trace packets into a list indexed by trace type (core, mem, shim, memtile) +# and the value is dictionary tile location (key) and trace packets (value) +# +# trace_pkts_sorted: list (idx = types of traces, currently 4, value = stream_dict) +# stream_dict: dict (key = row,col, value = list of word streams) +def trace_pkts_de_interleave(word_stream): + trace_pkts_sorted = list() + for t in range(NumTraceTypes): + trace_pkts_sorted.append(dict()) + + curr_pkt_type = 0 + curr_loc = "" + curr_vld = False # only used in the beginning + + for i in range(len(word_stream)): + if word_stream[i] == "": + break # TODO Assumes a blank line is the last line + if (i % 8) == 0: + pkt_hdr = parse_pkt_hdr_in_stream(int(word_stream[i], 16)) + if pkt_hdr["valid"]: + curr_loc = str(pkt_hdr["row"]) + "," + str(pkt_hdr["col"]) + valid_type_found = False + for tt in range(NumTraceTypes): + if pkt_hdr["type"] == tt: + curr_pkt_type = tt + if trace_pkts_sorted[tt].get(curr_loc) == None: + trace_pkts_sorted[tt][curr_loc] = list() + valid_type_found = True + if not valid_type_found: + sys.exit("Error: Invalid packet type") + curr_vld = True + else: + if curr_vld: # ignores first 8 chunks of data is pkt hdr was invalid + trace_pkts_sorted[curr_pkt_type][curr_loc].append(word_stream[i]) + return trace_pkts_sorted + + +# Convert trace packets into byte streams +def convert_to_byte_stream(toks_list): + byte_stream_list = list() + for l in toks_list: + byte_stream_dict = dict() + for loc, stream in l.items(): + byte_stream_dict[loc] = list() + f = ["", "a5a5a5a5"] + toks = [t for t in stream if not t in f] + events = [int(t, 16) for t in toks] + for event in events: + for top in range(4): + byte = 3 - top + opcode = event >> (byte * 8) & 0xFF + byte_stream_dict[loc].append(opcode) + byte_stream_list.append(byte_stream_dict) + return 
byte_stream_list + + +# Convert byte streams to equivalent packet commands +def convert_to_commands(byte_stream_list, zero=True): + commands = list() + for t in range(NumTraceTypes): + commands.append(dict()) + + for t in range(NumTraceTypes): + for key, byte_stream in byte_stream_list[t].items(): + cursor = 0 + commands[t][key] = list() + try: + while True: + if (byte_stream[cursor] & 0b11111011) == 0b11110000: + com = {"type": "Start", "timer_value": 0} + if not zero: + for i in range(7): + com["timer_value"] += (byte_stream[cursor + i + 1]) * ( + 256 ** (6 - i) + ) + commands[t][key].append(com) + cursor = cursor + 8 + if (byte_stream[cursor] & 0b11111100) == 0b11011100: + # We don't care about these + cursor = cursor + 4 + if (byte_stream[cursor] & 0b10000000) == 0b00000000: + com = {"type": "Single0"} + com["event"] = (byte_stream[cursor]) >> 4 & 0b111 + com["cycles"] = (byte_stream[cursor]) & 0b1111 + commands[t][key].append(com) + cursor = cursor + 1 + if (byte_stream[cursor] & 0b11100000) == 0b10000000: + com = {"type": "Single1"} + com["event"] = (byte_stream[cursor]) >> 2 & 0b111 + com["cycles"] = ((byte_stream[cursor]) & 0b11) * 256 + com["cycles"] += byte_stream[cursor + 1] + commands[t][key].append(com) + cursor = cursor + 2 + if (byte_stream[cursor] & 0b11100000) == 0b10100000: + com = {"type": "Single2"} + com["event"] = (byte_stream[cursor]) >> 2 & 0b111 + com["cycles"] = ((byte_stream[cursor]) & 0b11) * 256 * 256 + com["cycles"] += byte_stream[cursor + 1] * 256 + com["cycles"] += byte_stream[cursor + 2] + commands[t][key].append(com) + cursor = cursor + 3 + if (byte_stream[cursor] & 0b11110000) == 0b11000000: + com = {"type": "Multiple0"} + com["cycles"] = byte_stream[cursor + 1] & 0b1111 + events = (byte_stream[cursor] & 0b1111) << 4 + events = events + (byte_stream[cursor + 1] >> 4) + for i in range(0, 8): + e = (events >> i) & 0b1 + if e: + com["event" + str(i)] = i + commands[t][key].append(com) + cursor = cursor + 2 + if (byte_stream[cursor] & 
0b11111100) == 0b11010000: + com = {"type": "Multiple1"} + cycles = (byte_stream[cursor + 1] & 0b11) << 8 + com["cycles"] = cycles + (byte_stream[cursor + 2]) + events = (byte_stream[cursor] & 0b11) << 6 + events = events + (byte_stream[cursor + 1] >> 2) + for i in range(0, 8): + e = (events >> i) & 0b1 + if e: + com["event" + str(i)] = i + commands[t][key].append(com) + cursor = cursor + 3 + if (byte_stream[cursor] & 0b11111100) == 0b11010100: + com = {"type": "Multiple2"} + cycles = (byte_stream[cursor + 1] & 0b11) << 16 + cycles = cycles + ((byte_stream[cursor + 2]) << 8) + com["cycles"] = cycles + (byte_stream[cursor + 3]) + events = (byte_stream[cursor] & 0b11) << 6 + events = events + (byte_stream[cursor + 1] >> 2) + for i in range(0, 8): + e = (events >> i) & 0b1 + if e: + com["event" + str(i)] = i + commands[t][key].append(com) + cursor = cursor + 4 + if (byte_stream[cursor] & 0b11110000) == 0b11100000: + com = {"type": "Repeat0"} + com["repeats"] = (byte_stream[cursor]) & 0b1111 + commands[t][key].append(com) + cursor = cursor + 1 + if (byte_stream[cursor] & 0b11111100) == 0b11011000: + com = {"type": "Repeat1"} + com["repeats"] = ((byte_stream[cursor]) & 0b11) * 256 + com["repeats"] += byte_stream[cursor + 1] + commands[t][key].append(com) + cursor = cursor + 2 + if (byte_stream[cursor] & 0b11111111) == 0b11111110: + # No one likes you filler, get out of here + cursor = cursor + 1 + if (byte_stream[cursor] & 0b11111111) == 0b11111111: + com = {"type": "Event_Sync"} + commands[t][key].append(com) + cursor = cursor + 1 + except IndexError: + pass + + return commands + + +def parse_mlir_trace_events(mlir_module_str, colshift=None): + """ + Parse MLIR module to extract trace event configurations. + + This searches for npu.write32 and categorizes them based on address and row. + It's probably not the best way to do it but it's the initial implementation. + memtile and core/shim tiles have different addresses. 
For now, we distinguish + between core and shim tile by row=0 + """ + pid_events = list() + for t in range(NumTraceTypes): + pid_events.append(dict()) + + with Context(), Location.unknown(): + module = Module.parse(mlir_module_str) + + write32s = find_ops( + module.operation, + lambda o: isinstance(o.operation.opview, aiexdialect.NpuWrite32Op), + ) + device = find_ops( + module.operation, + lambda o: isinstance(o.operation.opview, aiedialect.DeviceOp), + ) + device = aiedialect.AIEDevice(int(device[0].device)) + target_model = aiedialect.get_target_model(device) + + for write32 in write32s: + address = None + row = None + col = None + value = None + if write32.address: + address = write32.address.value + if write32.row: + row = write32.row.value + if write32.column: + col = write32.column.value + if write32.value: + value = write32.value.value + + if row is None and col is None: + row = (address >> target_model.get_row_shift()) & 0x1F + col = (address >> target_model.get_column_shift()) & 0x1F + address = address & 0xFFFFF # 20 bits address + + if None in [row, col, address, value]: + print(f"[ERROR] Could not decode write32 op '{write32}'") + sys.exit(1) + + # Adjust column based on colshift + if colshift is not None: + col = col + colshift + key = str(row) + "," + str(col) + + # core event 0 + if address == 0x340E0: # 213216, match ignoring case + if row == 0: # shim + if pid_events[2].get(key) == None: + pid_events[2][key] = [0] * 8 + pid_events[2][key][0] = value & 0xFF + pid_events[2][key][1] = (value >> 8) & 0xFF + pid_events[2][key][2] = (value >> 16) & 0xFF + pid_events[2][key][3] = (value >> 24) & 0xFF + else: # core + if pid_events[0].get(key) == None: + pid_events[0][key] = [0] * 8 + pid_events[0][key][0] = value & 0xFF + pid_events[0][key][1] = (value >> 8) & 0xFF + pid_events[0][key][2] = (value >> 16) & 0xFF + pid_events[0][key][3] = (value >> 24) & 0xFF + # core event 1 + elif address == 0x340E4: # 213220, match ignoring case + if row == 0: # shim + 
if pid_events[2].get(key) == None: + pid_events[2][key] = [0] * 8 + pid_events[2][key][4] = value & 0xFF + pid_events[2][key][5] = (value >> 8) & 0xFF + pid_events[2][key][6] = (value >> 16) & 0xFF + pid_events[2][key][7] = (value >> 24) & 0xFF + else: # core + if pid_events[0].get(key) == None: + pid_events[0][key] = [0] * 8 + pid_events[0][key][4] = value & 0xFF + pid_events[0][key][5] = (value >> 8) & 0xFF + pid_events[0][key][6] = (value >> 16) & 0xFF + pid_events[0][key][7] = (value >> 24) & 0xFF + # mem event 0 + elif address == 0x140E0: # 82144 + if pid_events[1].get(key) == None: + pid_events[1][key] = [0] * 8 + pid_events[1][key][0] = value & 0xFF + pid_events[1][key][1] = (value >> 8) & 0xFF + pid_events[1][key][2] = (value >> 16) & 0xFF + pid_events[1][key][3] = (value >> 24) & 0xFF + # mem event 1 + elif address == 0x140E4: # 82148 + if pid_events[1].get(key) == None: + pid_events[1][key] = [0] * 8 + pid_events[1][key][4] = value & 0xFF + pid_events[1][key][5] = (value >> 8) & 0xFF + pid_events[1][key][6] = (value >> 16) & 0xFF + pid_events[1][key][7] = (value >> 24) & 0xFF + # memtile event 0 + elif address == 0x940E0: # 606432 + if pid_events[3].get(key) == None: + pid_events[3][key] = [0] * 8 + pid_events[3][key][0] = value & 0xFF + pid_events[3][key][1] = (value >> 8) & 0xFF + pid_events[3][key][2] = (value >> 16) & 0xFF + pid_events[3][key][3] = (value >> 24) & 0xFF + # memtile event 1 + elif address == 0x940E4: # 606436 + if pid_events[3].get(key) == None: + pid_events[3][key] = [0] * 8 + pid_events[3][key][4] = value & 0xFF + pid_events[3][key][5] = (value >> 8) & 0xFF + pid_events[3][key][6] = (value >> 16) & 0xFF + pid_events[3][key][7] = (value >> 24) & 0xFF + + return pid_events + + +def align_column_start_index(events, commands): + """ + Attempt to align the starting column of trace in the design (from 'events') + with the start first column observed in the trace ('commands'). 
This is needed + because the runtime/firmware can start the design on any valid column + """ + # find min column of commands + min_commands_col = float("inf") + for t in range(NumTraceTypes): + for loc in commands[t]: + col = int(loc.split(",")[1]) + if col < min_commands_col: + min_commands_col = col + + # find min column of events + min_events_col = float("inf") + for t in range(NumTraceTypes): + for loc in events[t]: + col = int(loc.split(",")[1]) + if col < min_events_col: + min_events_col = col + + # The shift is the difference between the expected and observed leftmost + # column for which trace was enabled (in 'events') + colshift = min_events_col - min_commands_col + + # Shift all event keys by colshift + new_events = [] + for t in range(NumTraceTypes): + updated = {} + for loc, l in events[t].items(): + row, col = map(int, loc.split(",")) + new_col = col - colshift + new_key = f"{row},{new_col}" + updated[new_key] = l + new_events.append(updated) + return new_events + + +def setup_trace_metadata(trace_events, pid_events): + """ + This sets up the trace metadata and also assigned the unique pid that's referred + eleswhere for each process (combination of tile(row,col) and trace type). + NOTE: This assume the pid_events has already be analyzed and populated. + """ + pid = 0 + for t in range(NumTraceTypes): + for loc in pid_events[t]: # return loc + process_name_metadata(trace_events, pid, t, loc) + for e in range(8): + thread_name_metadata(trace_events, t, loc, pid, e, pid_events) + pid_events[t][loc].append(pid) # assign unique pid + pid = pid + 1 + + +def convert_commands_to_json(trace_events, commands, pid_events, output_file): + """ + Convert commands to JSON format for trace events. + + commands: list (idx = trace type, value = byte_stream_dict) + byte_stream_dict: dict (key = row,col, value = list of commands) + """ + # byte_stream_dict for each trace type. 
+ for [tt, byte_stream_dict] in enumerate(commands): # tt = trace type + + for loc, command in byte_stream_dict.items(): # row,col with list of commands + timer = 0 # TODO Some way to set this or sync this between trace types and row,col + # timer on each execution is the time for the last execution + # so we by default will increment it by 1 for each event + if DEBUG: + print( + "tt: " + + str(tt) + + ", loc: " + + str(loc) + + ", NUM_EVENTS: " + + str(NUM_EVENTS), + file=output_file, + ) + + if loc in pid_events[tt]: + pid = pid_events[tt][loc][NUM_EVENTS] + else: + print( + "[ERROR] tile in", + loc, + "not found in trace packet data file (e.g trace.txt).", + file=sys.stderr, + ) + tiles = [] + for tt_tmp in range(len(commands)): + for keys in pid_events[tt_tmp]: + tiles.append(keys) + print("Defined tiles in design are at:", tiles, file=sys.stderr) + print( + "Consider changing --colshift value if you think this is an error.", + file=sys.stderr, + ) + sys.exit(1) + + active_events = dict() + for i in range(8): # 8 max events at a time + active_events[i] = 0 + + if DEBUG: + print("num commands:", len(command), file=output_file) + for c in command: + t = c["type"] + if "Single" in t: + event = c["event"] + cycles = int(c["cycles"]) + timer = timer + 1 + multiple_list = list() + multiple_list.append(c["event"]) + deactivate_events( + multiple_list, + active_events, + timer, + cycles, + pid, + tt, + loc, + pid_events, + trace_events, + ) + timer = timer + cycles + activate_event( + event, + tt, + loc, + timer, + pid, + active_events, + pid_events, + trace_events, + ) + + elif "Multiple" in t: + cycles = int(c["cycles"]) + timer = timer + 1 + multiple_list = list() + for k in c.keys(): + if "event" in k: + multiple_list.append(c[k]) + deactivate_events( + multiple_list, + active_events, + timer, + cycles, + pid, + tt, + loc, + pid_events, + trace_events, + ) + timer = timer + cycles + + for k in c.keys(): + if "event" in k: + activate_event( + c[k], + tt, + loc, + 
timer, + pid, + active_events, + pid_events, + trace_events, + ) + + elif "Repeat" in t: + if ( + cycles == 0 + ): # last event has cycles == 0 so we just extend it by the repaet count + timer = timer + int(c["repeats"]) + else: + for repeats_cnt in range(int(c["repeats"])): + timer = timer + 1 + deactivate_events( + multiple_list, + active_events, + timer, + cycles, + pid, + tt, + loc, + pid_events, + trace_events, + ) + timer = timer + cycles + if len(multiple_list) > 1: + for k in c.keys(): + if "event" in k: + activate_event( + c[k], + tt, + loc, + timer, + pid, + active_events, + pid_events, + trace_events, + ) + else: + activate_event( + event, + tt, + loc, + timer, + pid, + active_events, + pid_events, + trace_events, + ) + + +# Main function for command-line usage +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input", help="Input trace file", required=True) + parser.add_argument("--mlir", help="mlir source file", required=True) + parser.add_argument( + "--colshift", help="column shift adjustment to source mlir", required=False + ) + parser.add_argument("--output", help="Output json file", required=True) + parser.add_argument("--debug", help="debug mode", required=False) + return parser.parse_args(sys.argv[1:]) + + +if __name__ == "__main__": + # Command-line execution + opts = parse_args() + + DEBUG = opts.debug + if DEBUG: + print("Debug mode enable\n") + + # set colshift based on optional argument + colshift = int(opts.colshift) if opts.colshift else None + + # Use parse_and_save_trace for consistent processing + try: + parse_and_save_trace( + trace_input=opts.input, + mlir_input=opts.mlir, + colshift=colshift, + debug=DEBUG, + output_file=opts.output, # Pass output file path directly + ) + + print(f"Trace data successfully parsed and saved to {opts.output}") + + except Exception as e: + print(f"ERROR: Failed to parse trace data: {e}", file=sys.stderr) + exit(1) diff --git a/python/iron/trace_events_enum.py 
b/python/iron/trace_events_enum.py new file mode 100644 index 00000000000..e2c1e23d9b6 --- /dev/null +++ b/python/iron/trace_events_enum.py @@ -0,0 +1,561 @@ +# Enumeration of AIE2 trace events +# Automatically generated from utils/generate_events_enum.py + +from enum import Enum + + +class CoreEvent(Enum): + NONE = 0 + TRUE = 1 + GROUP_0 = 2 + TIMER_SYNC = 3 + TIMER_VALUE_REACHED = 4 + PERF_CNT_0 = 5 + PERF_CNT_1 = 6 + PERF_CNT_2 = 7 + PERF_CNT_3 = 8 + COMBO_EVENT_0 = 9 + COMBO_EVENT_1 = 10 + COMBO_EVENT_2 = 11 + COMBO_EVENT_3 = 12 + EDGE_DETECTION_EVENT_0 = 13 + EDGE_DETECTION_EVENT_1 = 14 + GROUP_PC_EVENT = 15 + PC_0 = 16 + PC_1 = 17 + PC_2 = 18 + PC_3 = 19 + PC_RANGE_0_1 = 20 + PC_RANGE_2_3 = 21 + GROUP_STALL = 22 + MEMORY_STALL = 23 + STREAM_STALL = 24 + CASCADE_STALL = 25 + LOCK_STALL = 26 + DEBUG_HALTED = 27 + ACTIVE = 28 + DISABLED = 29 + ECC_ERROR_STALL = 30 + ECC_SCRUBBING_STALL = 31 + GROUP_PROGRAM_FLOW = 32 + INSTR_EVENT_0 = 33 + INSTR_EVENT_1 = 34 + INSTR_CALL = 35 + INSTR_RETURN = 36 + INSTR_VECTOR = 37 + INSTR_LOAD = 38 + INSTR_STORE = 39 + INSTR_STREAM_GET = 40 + INSTR_STREAM_PUT = 41 + INSTR_CASCADE_GET = 42 + INSTR_CASCADE_PUT = 43 + INSTR_LOCK_ACQUIRE_REQ = 44 + INSTR_LOCK_RELEASE_REQ = 45 + GROUP_ERRORS_0 = 46 + GROUP_ERRORS_1 = 47 + SRS_OVERFLOW = 48 + UPS_OVERFLOW = 49 + FP_HUGE = 50 + INT_FP_0 = 51 + FP_INVALID = 52 + FP_INF = 53 + rsvd_54 = 54 + PM_REG_ACCESS_FAILURE = 55 + STREAM_PKT_PARITY_ERROR = 56 + CONTROL_PKT_ERROR = 57 + AXI_MM_SLAVE_ERROR = 58 + INSTR_DECOMPRSN_ERROR = 59 + DM_ADDRESS_OUT_OF_RANGE = 60 + PM_ECC_ERROR_SCRUB_CORRECTED = 61 + PM_ECC_ERROR_SCRUB_2BIT = 62 + PM_ECC_ERROR_1BIT = 63 + PM_ECC_ERROR_2BIT = 64 + PM_ADDRESS_OUT_OF_RANGE = 65 + DM_ACCESS_TO_UNAVAILABLE = 66 + LOCK_ACCESS_TO_UNAVAILABLE = 67 + INSTR_WARNING = 68 + INSTR_ERROR = 69 + DECOMPRESSION_UNDERFLOW = 70 + STREAM_SWITCH_PORT_PARITY_ERROR = 71 + PROCESSOR_BUS_ERROR = 72 + GROUP_STREAM_SWITCH = 73 + PORT_IDLE_0 = 74 + PORT_RUNNING_0 = 75 + PORT_STALLED_0 = 
76 + PORT_TLAST_0 = 77 + PORT_IDLE_1 = 78 + PORT_RUNNING_1 = 79 + PORT_STALLED_1 = 80 + PORT_TLAST_1 = 81 + PORT_IDLE_2 = 82 + PORT_RUNNING_2 = 83 + PORT_STALLED_2 = 84 + PORT_TLAST_2 = 85 + PORT_IDLE_3 = 86 + PORT_RUNNING_3 = 87 + PORT_STALLED_3 = 88 + PORT_TLAST_3 = 89 + PORT_IDLE_4 = 90 + PORT_RUNNING_4 = 91 + PORT_STALLED_4 = 92 + PORT_TLAST_4 = 93 + PORT_IDLE_5 = 94 + PORT_RUNNING_5 = 95 + PORT_STALLED_5 = 96 + PORT_TLAST_5 = 97 + PORT_IDLE_6 = 98 + PORT_RUNNING_6 = 99 + PORT_STALLED_6 = 100 + PORT_TLAST_6 = 101 + PORT_IDLE_7 = 102 + PORT_RUNNING_7 = 103 + PORT_STALLED_7 = 104 + PORT_TLAST_7 = 105 + GROUP_BROADCAST = 106 + BROADCAST_0 = 107 + BROADCAST_1 = 108 + BROADCAST_2 = 109 + BROADCAST_3 = 110 + BROADCAST_4 = 111 + BROADCAST_5 = 112 + BROADCAST_6 = 113 + BROADCAST_7 = 114 + BROADCAST_8 = 115 + BROADCAST_9 = 116 + BROADCAST_10 = 117 + BROADCAST_11 = 118 + BROADCAST_12 = 119 + BROADCAST_13 = 120 + BROADCAST_14 = 121 + BROADCAST_15 = 122 + GROUP_USER_EVENT = 123 + USER_EVENT_0 = 124 + USER_EVENT_1 = 125 + USER_EVENT_2 = 126 + USER_EVENT_3 = 127 + + +class MemEvent(Enum): + NONE = 0 + TRUE = 1 + GROUP_0 = 2 + TIMER_SYNC = 3 + TIMER_VALUE_REACHED = 4 + PERF_CNT_0 = 5 + PERF_CNT_1 = 6 + COMBO_EVENT_0 = 7 + COMBO_EVENT_1 = 8 + COMBO_EVENT_2 = 9 + COMBO_EVENT_3 = 10 + EDGE_DETECTION_EVENT_0 = 11 + EDGE_DETECTION_EVENT_1 = 12 + rsvd_13 = 13 + rsvd_14 = 14 + GROUP_WATCHPOINT = 15 + WATCHPOINT_0 = 16 + WATCHPOINT_1 = 17 + GROUP_DMA_ACTIVITY = 18 + DMA_S2MM_0_START_TASK = 19 + DMA_S2MM_1_START_TASK = 20 + DMA_MM2S_0_START_TASK = 21 + DMA_MM2S_1_START_TASK = 22 + DMA_S2MM_0_FINISHED_BD = 23 + DMA_S2MM_1_FINISHED_BD = 24 + DMA_MM2S_0_FINISHED_BD = 25 + DMA_MM2S_1_FINISHED_BD = 26 + DMA_S2MM_0_FINISHED_TASK = 27 + DMA_S2MM_1_FINISHED_TASK = 28 + DMA_MM2S_0_FINISHED_TASK = 29 + DMA_MM2S_1_FINISHED_TASK = 30 + DMA_S2MM_0_STALLED_LOCK = 31 + DMA_S2MM_1_STALLED_LOCK = 32 + DMA_MM2S_0_STALLED_LOCK = 33 + DMA_MM2S_1_STALLED_LOCK = 34 + DMA_S2MM_0_STREAM_STARVATION = 35 + 
DMA_S2MM_1_STREAM_STARVATION = 36 + DMA_MM2S_0_STREAM_BACKPRESSURE = 37 + DMA_MM2S_1_STREAM_BACKPRESSURE = 38 + DMA_S2MM_0_MEMORY_BACKPRESSURE = 39 + DMA_S2MM_1_MEMORY_BACKPRESSURE = 40 + DMA_MM2S_0_MEMORY_STARVATION = 41 + DMA_MM2S_1_MEMORY_STARVATION = 42 + GROUP_LOCK = 43 + LOCK_SEL0_ACQ_EQ = 44 + LOCK_SEL0_ACQ_GE = 45 + LOCK_0_REL = 46 + LOCK_SEL0_EQUAL_TO_VALUE = 47 + LOCK_SEL1_ACQ_EQ = 48 + LOCK_SEL1_ACQ_GE = 49 + LOCK_1_REL = 50 + LOCK_SEL1_EQUAL_TO_VALUE = 51 + LOCK_SEL2_ACQ_EQ = 52 + LOCK_SEL2_ACQ_GE = 53 + LOCK_2_REL = 54 + LOCK_SEL2_EQUAL_TO_VALUE = 55 + LOCK_SEL3_ACQ_EQ = 56 + LOCK_SEL3_ACQ_GE = 57 + LOCK_3_REL = 58 + LOCK_SEL3_EQUAL_TO_VALUE = 59 + LOCK_SEL4_ACQ_EQ = 60 + LOCK_SEL4_ACQ_GE = 61 + LOCK_4_REL = 62 + LOCK_SEL4_EQUAL_TO_VALUE = 63 + LOCK_SEL5_ACQ_EQ = 64 + LOCK_SEL5_ACQ_GE = 65 + LOCK_5_REL = 66 + LOCK_SEL5_EQUAL_TO_VALUE = 67 + LOCK_SEL6_ACQ_EQ = 68 + LOCK_SEL6_ACQ_GE = 69 + LOCK_6_REL = 70 + LOCK_SEL6_EQUAL_TO_VALUE = 71 + LOCK_SEL7_ACQ_EQ = 72 + LOCK_SEL7_ACQ_GE = 73 + LOCK_7_REL = 74 + LOCK_SEL7_EQUAL_TO_VALUE = 75 + GROUP_MEMORY_CONFLICT = 76 + CONFLICT_DM_BANK_0 = 77 + CONFLICT_DM_BANK_1 = 78 + CONFLICT_DM_BANK_2 = 79 + CONFLICT_DM_BANK_3 = 80 + CONFLICT_DM_BANK_4 = 81 + CONFLICT_DM_BANK_5 = 82 + CONFLICT_DM_BANK_6 = 83 + CONFLICT_DM_BANK_7 = 84 + rsvd_85 = 85 + GROUP_ERRORS = 86 + DM_ECC_ERROR_SCRUB_CORRECTED = 87 + DM_ECC_ERROR_SCRUB_2BIT = 88 + DM_ECC_ERROR_1BIT = 89 + DM_ECC_ERROR_2BIT = 90 + DM_PARITY_ERROR_BANK_2 = 91 + DM_PARITY_ERROR_BANK_3 = 92 + DM_PARITY_ERROR_BANK_4 = 93 + DM_PARITY_ERROR_BANK_5 = 94 + DM_PARITY_ERROR_BANK_6 = 95 + DM_PARITY_ERROR_BANK_7 = 96 + DMA_S2MM_0_ERROR = 97 + DMA_S2MM_1_ERROR = 98 + DMA_MM2S_0_ERROR = 99 + DMA_MM2S_1_ERROR = 100 + LOCK_ERROR = 101 + DMA_TASK_TOKEN_STALL = 102 + rsvd_103 = 103 + rsvd_104 = 104 + rsvd_105 = 105 + GROUP_BROADCAST = 106 + BROADCAST_0 = 107 + BROADCAST_1 = 108 + BROADCAST_2 = 109 + BROADCAST_3 = 110 + BROADCAST_4 = 111 + BROADCAST_5 = 112 + BROADCAST_6 = 113 + 
BROADCAST_7 = 114 + BROADCAST_8 = 115 + BROADCAST_9 = 116 + BROADCAST_10 = 117 + BROADCAST_11 = 118 + BROADCAST_12 = 119 + BROADCAST_13 = 120 + BROADCAST_14 = 121 + BROADCAST_15 = 122 + GROUP_USER_EVENT = 123 + USER_EVENT_0 = 124 + USER_EVENT_1 = 125 + USER_EVENT_2 = 126 + USER_EVENT_3 = 127 + + +class ShimTileEvent(Enum): + NONE = 0 + TRUE = 1 + GROUP_0 = 2 + TIMER_SYNC = 3 + TIMER_VALUE_REACHED = 4 + PERF_CNT_0 = 5 + PERF_CNT_1 = 6 + COMBO_EVENT_0 = 7 + COMBO_EVENT_1 = 8 + COMBO_EVENT_2 = 9 + COMBO_EVENT_3 = 10 + EDGE_DETECTION_EVENT_0 = 11 + EDGE_DETECTION_EVENT_1 = 12 + GROUP_DMA_ACTIVITY = 13 + DMA_S2MM_0_START_TASK = 14 + DMA_S2MM_1_START_TASK = 15 + DMA_MM2S_0_START_TASK = 16 + DMA_MM2S_1_START_TASK = 17 + DMA_S2MM_0_FINISHED_BD = 18 + DMA_S2MM_1_FINISHED_BD = 19 + DMA_MM2S_0_FINISHED_BD = 20 + DMA_MM2S_1_FINISHED_BD = 21 + DMA_S2MM_0_FINISHED_TASK = 22 + DMA_S2MM_1_FINISHED_TASK = 23 + DMA_MM2S_0_FINISHED_TASK = 24 + DMA_MM2S_1_FINISHED_TASK = 25 + DMA_S2MM_0_STALLED_LOCK = 26 + DMA_S2MM_1_STALLED_LOCK = 27 + DMA_MM2S_0_STALLED_LOCK = 28 + DMA_MM2S_1_STALLED_LOCK = 29 + DMA_S2MM_0_STREAM_STARVATION = 30 + DMA_S2MM_1_STREAM_STARVATION = 31 + DMA_MM2S_0_STREAM_BACKPRESSURE = 32 + DMA_MM2S_1_STREAM_BACKPRESSURE = 33 + DMA_S2MM_0_MEMORY_BACKPRESSURE = 34 + DMA_S2MM_1_MEMORY_BACKPRESSURE = 35 + DMA_MM2S_0_MEMORY_STARVATION = 36 + DMA_MM2S_1_MEMORY_STARVATION = 37 + GROUP_LOCK = 38 + LOCK_0_ACQ_EQ = 39 + LOCK_0_ACQ_GE = 40 + LOCK_0_REL = 41 + LOCK_0_EQUAL_TO_VALUE = 42 + LOCK_1_ACQ_EQ = 43 + LOCK_1_ACQ_GE = 44 + LOCK_1_REL = 45 + LOCK_1_EQUAL_TO_VALUE = 46 + LOCK_2_ACQ_EQ = 47 + LOCK_2_ACQ_GE = 48 + LOCK_2_REL = 49 + LOCK_2_EQUAL_TO_VALUE = 50 + LOCK_3_ACQ_EQ = 51 + LOCK_3_ACQ_GE = 52 + LOCK_3_REL = 53 + LOCK_3_EQUAL_TO_VALUE = 54 + LOCK_4_ACQ_EQ = 55 + LOCK_4_ACQ_GE = 56 + LOCK_4_REL = 57 + LOCK_4_EQUAL_TO_VALUE = 58 + LOCK_5_ACQ_EQ = 59 + LOCK_5_ACQ_GE = 60 + LOCK_5_REL = 61 + LOCK_5_EQUAL_TO_VALUE = 62 + GROUP_ERRORS = 63 + AXI_MM_SLAVE_ERROR = 64 + 
CONTROL_PKT_ERROR = 65 + STREAM_SWITCH_PARITY_ERROR = 66 + AXI_MM_DECODE_NSU_ERROR = 67 + AXI_MM_SLAVE_NSU_ERROR = 68 + AXI_MM_UNSUPPORTED_TRAFFIC = 69 + AXI_MM_UNSECURE_ACCESS_IN_SECURE_MODE = 70 + AXI_MM_BYTE_STROBE_ERROR = 71 + DMA_S2MM_ERROR = 72 + DMA_MM2S_ERROR = 73 + LOCK_ERROR = 74 + DMA_TASK_TOKEN_STALL = 75 + GROUP_STREAM_SWITCH = 76 + PORT_IDLE_0 = 77 + PORT_RUNNING_0 = 78 + PORT_STALLED_0 = 79 + PORT_TLAST_0 = 80 + PORT_IDLE_1 = 81 + PORT_RUNNING_1 = 82 + PORT_STALLED_1 = 83 + PORT_TLAST_1 = 84 + PORT_IDLE_2 = 85 + PORT_RUNNING_2 = 86 + PORT_STALLED_2 = 87 + PORT_TLAST_2 = 88 + PORT_IDLE_3 = 89 + PORT_RUNNING_3 = 90 + PORT_STALLED_3 = 91 + PORT_TLAST_3 = 92 + PORT_IDLE_4 = 93 + PORT_RUNNING_4 = 94 + PORT_STALLED_4 = 95 + PORT_TLAST_4 = 96 + PORT_IDLE_5 = 97 + PORT_RUNNING_5 = 98 + PORT_STALLED_5 = 99 + PORT_TLAST_5 = 100 + PORT_IDLE_6 = 101 + PORT_RUNNING_6 = 102 + PORT_STALLED_6 = 103 + PORT_TLAST_6 = 104 + PORT_IDLE_7 = 105 + PORT_RUNNING_7 = 106 + PORT_STALLED_7 = 107 + PORT_TLAST_7 = 108 + GROUP_BROADCAST_A = 109 + BROADCAST_A_0 = 110 + BROADCAST_A_1 = 111 + BROADCAST_A_2 = 112 + BROADCAST_A_3 = 113 + BROADCAST_A_4 = 114 + BROADCAST_A_5 = 115 + BROADCAST_A_6 = 116 + BROADCAST_A_7 = 117 + BROADCAST_A_8 = 118 + BROADCAST_A_9 = 119 + BROADCAST_A_10 = 120 + BROADCAST_A_11 = 121 + BROADCAST_A_12 = 122 + BROADCAST_A_13 = 123 + BROADCAST_A_14 = 124 + BROADCAST_A_15 = 125 + USER_EVENT_0 = 126 + USER_EVENT_1 = 127 + + +class MemTileEvent(Enum): + NONE = 0 + TRUE = 1 + GROUP_0 = 2 + TIMER_SYNC = 3 + TIMER_VALUE_REACHED = 4 + PERF_CNT0_EVENT = 5 + PERF_CNT1_EVENT = 6 + PERF_CNT2_EVENT = 7 + PERF_CNT3_EVENT = 8 + COMBO_EVENT_0 = 9 + COMBO_EVENT_1 = 10 + COMBO_EVENT_2 = 11 + COMBO_EVENT_3 = 12 + EDGE_DETECTION_EVENT_0 = 13 + EDGE_DETECTION_EVENT_1 = 14 + GROUP_WATCHPOINT = 15 + WATCHPOINT_0 = 16 + WATCHPOINT_1 = 17 + WATCHPOINT_2 = 18 + WATCHPOINT_3 = 19 + GROUP_DMA_ACTIVITY = 20 + DMA_S2MM_SEL0_START_TASK = 21 + DMA_S2MM_SEL1_START_TASK = 22 + 
DMA_MM2S_SEL0_START_TASK = 23 + DMA_MM2S_SEL1_START_TASK = 24 + DMA_S2MM_SEL0_FINISHED_BD = 25 + DMA_S2MM_SEL1_FINISHED_BD = 26 + DMA_MM2S_SEL0_FINISHED_BD = 27 + DMA_MM2S_SEL1_FINISHED_BD = 28 + DMA_S2MM_SEL0_FINISHED_TASK = 29 + DMA_S2MM_SEL1_FINISHED_TASK = 30 + DMA_MM2S_SEL0_FINISHED_TASK = 31 + DMA_MM2S_SEL1_FINISHED_TASK = 32 + DMA_S2MM_SEL0_STALLED_LOCK = 33 + DMA_S2MM_SEL1_STALLED_LOCK = 34 + DMA_MM2S_SEL0_STALLED_LOCK = 35 + DMA_MM2S_SEL1_STALLED_LOCK = 36 + DMA_S2MM_SEL0_STREAM_STARVATION = 37 + DMA_S2MM_SEL1_STREAM_STARVATION = 38 + DMA_MM2S_SEL0_STREAM_BACKPRESSURE = 39 + DMA_MM2S_SEL1_STREAM_BACKPRESSURE = 40 + DMA_S2MM_SEL0_MEMORY_BACKPRESSURE = 41 + DMA_S2MM_SEL1_MEMORY_BACKPRESSURE = 42 + DMA_MM2S_SEL0_MEMORY_STARVATION = 43 + DMA_MM2S_SEL1_MEMORY_STARVATION = 44 + GROUP_LOCK = 45 + LOCK_SEL0_ACQ_EQ = 46 + LOCK_SEL0_ACQ_GE = 47 + LOCK_SEL0_REL = 48 + LOCK_SEL0_EQUAL_TO_VALUE = 49 + LOCK_SEL1_ACQ_EQ = 50 + LOCK_SEL1_ACQ_GE = 51 + LOCK_SEL1_REL = 52 + LOCK_SEL1_EQUAL_TO_VALUE = 53 + LOCK_SEL2_ACQ_EQ = 54 + LOCK_SEL2_ACQ_GE = 55 + LOCK_SEL2_REL = 56 + LOCK_SEL2_EQUAL_TO_VALUE = 57 + LOCK_SEL3_ACQ_EQ = 58 + LOCK_SEL3_ACQ_GE = 59 + LOCK_SEL3_REL = 60 + LOCK_SEL3_EQUAL_TO_VALUE = 61 + LOCK_SEL4_ACQ_EQ = 62 + LOCK_SEL4_ACQ_GE = 63 + LOCK_SEL4_REL = 64 + LOCK_SEL4_EQUAL_TO_VALUE = 65 + LOCK_SEL5_ACQ_EQ = 66 + LOCK_SEL5_ACQ_GE = 67 + LOCK_SEL5_REL = 68 + LOCK_SEL5_EQUAL_TO_VALUE = 69 + LOCK_SEL6_ACQ_EQ = 70 + LOCK_SEL6_ACQ_GE = 71 + LOCK_SEL6_REL = 72 + LOCK_SEL6_EQUAL_TO_VALUE = 73 + LOCK_SEL7_ACQ_EQ = 74 + LOCK_SEL7_ACQ_GE = 75 + LOCK_SEL7_REL = 76 + LOCK_SEL7_EQUAL_TO_VALUE = 77 + GROUP_STREAM_SWITCH = 78 + PORT_IDLE_0 = 79 + PORT_RUNNING_0 = 80 + PORT_STALLED_0 = 81 + PORT_TLAST_0 = 82 + PORT_IDLE_1 = 83 + PORT_RUNNING_1 = 84 + PORT_STALLED_1 = 85 + PORT_TLAST_1 = 86 + PORT_IDLE_2 = 87 + PORT_RUNNING_2 = 88 + PORT_STALLED_2 = 89 + PORT_TLAST_2 = 90 + PORT_IDLE_3 = 91 + PORT_RUNNING_3 = 92 + PORT_STALLED_3 = 93 + PORT_TLAST_3 = 94 + PORT_IDLE_4 = 95 + 
PORT_RUNNING_4 = 96 + PORT_STALLED_4 = 97 + PORT_TLAST_4 = 98 + PORT_IDLE_5 = 99 + PORT_RUNNING_5 = 100 + PORT_STALLED_5 = 101 + PORT_TLAST_5 = 102 + PORT_IDLE_6 = 103 + PORT_RUNNING_6 = 104 + PORT_STALLED_6 = 105 + PORT_TLAST_6 = 106 + PORT_IDLE_7 = 107 + PORT_RUNNING_7 = 108 + PORT_STALLED_7 = 109 + PORT_TLAST_7 = 110 + GROUP_MEMORY_CONFLICT = 111 + CONFLICT_DM_BANK_0 = 112 + CONFLICT_DM_BANK_1 = 113 + CONFLICT_DM_BANK_2 = 114 + CONFLICT_DM_BANK_3 = 115 + CONFLICT_DM_BANK_4 = 116 + CONFLICT_DM_BANK_5 = 117 + CONFLICT_DM_BANK_6 = 118 + CONFLICT_DM_BANK_7 = 119 + CONFLICT_DM_BANK_8 = 120 + CONFLICT_DM_BANK_9 = 121 + CONFLICT_DM_BANK_10 = 122 + CONFLICT_DM_BANK_11 = 123 + CONFLICT_DM_BANK_12 = 124 + CONFLICT_DM_BANK_13 = 125 + CONFLICT_DM_BANK_14 = 126 + CONFLICT_DM_BANK_15 = 127 + GROUP_ERRORS = 128 + DM_ECC_ERROR_SCRUB_CORRECTED = 129 + DM_ECC_ERROR_SCRUB_2BIT = 130 + DM_ECC_ERROR_1BIT = 131 + DM_ECC_ERROR_2BIT = 132 + DMA_S2MM_ERROR = 133 + DMA_MM2S_ERROR = 134 + STREAM_SWITCH_PARITY_ERROR = 135 + STREAM_PKT_ERROR = 136 + CONTROL_PKT_ERROR = 137 + AXI_MM_SLAVE_ERROR = 138 + LOCK_ERROR = 139 + DMA_TASK_TOKEN_STALL = 140 + GROUP_BROADCAST = 141 + BROADCAST_0 = 142 + BROADCAST_1 = 143 + BROADCAST_2 = 144 + BROADCAST_3 = 145 + BROADCAST_4 = 146 + BROADCAST_5 = 147 + BROADCAST_6 = 148 + BROADCAST_7 = 149 + BROADCAST_8 = 150 + BROADCAST_9 = 151 + BROADCAST_10 = 152 + BROADCAST_11 = 153 + BROADCAST_12 = 154 + BROADCAST_13 = 155 + BROADCAST_14 = 156 + BROADCAST_15 = 157 + GROUP_USER_EVENT = 158 + USER_EVENT_0 = 159 + USER_EVENT_1 = 160 diff --git a/python/iron/worker.py b/python/iron/worker.py index 9791f65c72b..f1202d768cf 100644 --- a/python/iron/worker.py +++ b/python/iron/worker.py @@ -4,7 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # -# (c) Copyright 2024 Advanced Micro Devices, Inc. +# (c) Copyright 2024-2025 Advanced Micro Devices, Inc. 
+
 import contextvars
 import sys
 from typing import Callable
@@ -19,6 +20,7 @@
 from .kernel import Kernel, ExternalFunction
 from .globalbuffer import GlobalBuffer
 from .resolvable import Resolvable
+from . import trace


 class Worker(ObjectFifoEndpoint):
@@ -53,7 +55,7 @@ def __init__(
             stack_size (int, optional): The stack_size in bytes to be allocated for the worker. Defaults to 1024 bytes.
             allocation_scheme (str, optional): The memory allocation scheme to use for the Worker, either 'basic-sequential' or 'bank-aware'. If None, defaults to bank-aware. Will override any allocation scheme set on the tile given as placement.
-            trace (int, optional): If >0, enable tracing for this worker.
+            trace (int, optional): If >0, enable tracing for this worker. If None and tracing is active, automatically set to 1.

         Raises:
             ValueError: Parameters are validated.
@@ -64,6 +66,16 @@
         self.allocation_scheme = allocation_scheme
         if allocation_scheme:
             self._tile.allocation_scheme = allocation_scheme
+
+        # Auto-enable tracing if not explicitly set and tracing is active.
+        # NOTE: the `trace` *parameter* shadows the `trace` module imported at
+        # the top of this file, so the module must be re-bound under an alias
+        # here; calling `trace._get_trace_active()` on the (None) parameter
+        # would raise AttributeError.
+        if trace is None:
+            from . import trace as _trace_mod
+            if _trace_mod._get_trace_active():
+                trace = 1
+
         self.trace = trace
         self.trace_events = trace_events