
Linearize read indices pipeline #1196

Open
willghatch wants to merge 3 commits into main from
users/willghatch/linearize

Conversation

@willghatch
Contributor

The main goal of this is to have memory addresses for reads in a loop be simplified to `start + IV * stride`, with each of those values either constant or at least hoistable outside of the loop body.

Adds a pre-codegen pipeline that flattens N-dimensional read addresses into 1-dimensional `LINEAR_INDEX` accesses:

- `flatten_read_indices`: rewrites mapped reads to use a single flat offset
- `annotate_iv_strides`: extracts constant IV strides for loop-carried reads
- Codegen `LINEAR_INDEX` paths for both vector loads and GatherToLDS
- Removes the old `_try_iv_split_offset` codegen approach in favor of the new pipeline-based linearization
- Adds helper functions (`mem_simplify`, `linearize_dims`, `_infer_floor_to_exact`) in `mapping_utils` for symbolic floor/Mod cancellation
- Adjusts bounds for linearized reads

This adds new lit tests showing that with our mxfp4 shuffle layout we can generate linearized reads with constant stride.

This also changes a large number of other lit tests. Disclaimer: I was asked to get this PR up ASAP; the lit-test churn is substantial, and I have not yet validated that all of the updated tests are correct.

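The flattening described above can be sketched as plain integer arithmetic (helper names here are hypothetical; the real pass works on symbolic sympy expressions):

```python
def flatten_index(indices, strides):
    """Collapse an N-D index into a single flat offset: sum(i * s)."""
    return sum(i * s for i, s in zip(indices, strides))


def linearized_loop_offset(start_indices, iv_coeffs, strides, iv):
    """If each per-dim index has the form start_d + iv * coeff_d, the flat
    offset becomes start + iv * stride, with both terms loop-invariant."""
    start = flatten_index(start_indices, strides)
    stride = flatten_index(iv_coeffs, strides)
    return start + iv * stride


# Row-major strides for a 4x8 buffer are (8, 1).
# Per-dim indices (2 + iv*1, 3 + iv*4) flatten to 19 + iv*12.
assert linearized_loop_offset((2, 3), (1, 4), (8, 1), 0) == 19
assert linearized_loop_offset((2, 3), (1, 4), (8, 1), 5) == 19 + 5 * 12
```

Because `start` and `stride` no longer depend on the induction variable, both can be computed once outside the loop.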

Signed-off-by: William G Hatch <[email protected]>
```python
flat_offset = idx_seq.start
iv_stride_val = idx_seq.stride

precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None)
```
Contributor

It's fine for now, but we should get rid of it one day.

Comment on lines +939 to +1096
```python
if precomputed_mask_expr is not None and not buffer_ops_enabled:
    mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr)
    mask_vec_type = VectorType.get(
        [elements_per_thread], IntegerType.get_signless(1)
    )
    if mask.type != mask_vec_type:
        mask = vector_d.broadcast(mask_vec_type, mask)
else:
    mask = _build_mask(emitter, index, elements_per_thread, bounds)

is_global = get_custom(memory).type.address_space != SHARED_ADDRESS_SPACE
use_llvm_load = flags != MemoryAccessFlags.NONE

if (
    is_global
    and not use_llvm_load
    and not read_meets_hw_transpose_requirements(
        get_custom(node), emitter.constraints, emitter.options.target
    )
):
    subs_map = add_emitter_subs(emitter, dynamic_vals_map_start)

    kb_type = MemRefType(kb_src.type)
    phys_strides, _ = kb_type.get_strides_and_offset()
    dyn_sentinel = ShapedType.get_dynamic_stride_or_offset()
    if any(s == dyn_sentinel for s in phys_strides):
        sym_strides = list(
            strides_from_symbolic_shape(
                IndexingContext.current(),
                input_shape,
                allow_mixed_shapes=True,
            )
        )
    else:
        sym_strides = [sympy.Integer(s) for s in phys_strides]

    ip = InsertionPoint.current
    owner = ip.block.owner
    is_in_loop = (
        not isinstance(owner, func_d.FuncOp) and owner.name == "scf.for"
    )
    has_iv = iv_stride_val != 0 and is_in_loop
    hoist_ip = InsertionPoint(owner) if is_in_loop else None

    if hoist_ip is not None:
        with hoist_ip:
            strides_vals = [gen_sympy_index(subs_map, s) for s in sym_strides]
            zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(
                sym_strides
            )
            lin_src, _ = _linearize_memref(
                kb_src, zero_indices, zero_indices, strides_vals
            )
            if buffer_ops_enabled:
                valid_bytes = _compute_valid_bytes(
                    lin_src,
                    element_type,
                    input_shape,
                    emitter,
                )
                lin_src = _cast_buffer_and_encode_stride(
                    lin_src,
                    strides_vals,
                    element_type,
                    valid_bytes,
                )
    else:
        strides_vals = [gen_sympy_index(subs_map, s) for s in sym_strides]
        zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(sym_strides)
        lin_src, _ = _linearize_memref(
            kb_src, zero_indices, zero_indices, strides_vals
        )
        if buffer_ops_enabled:
            valid_bytes = _compute_valid_bytes(
                lin_src,
                element_type,
                input_shape,
                emitter,
            )
            lin_src = _cast_buffer_and_encode_stride(
                lin_src,
                strides_vals,
                element_type,
                valid_bytes,
            )

    total_offset = gen_sympy_index(subs_map, flat_offset)

    if mask is None:
        result = vector_d.load(vector_type, lin_src, [total_offset])
    else:
        el_type = vector_type.element_type
        zero = arith_d.constant(el_type, get_constant_attr(0, el_type))
        passthru = vector_d.broadcast(vector_type, zero)
        result = vector_d.maskedload(
            vector_type, lin_src, [total_offset], mask, passthru
        )
    emitter.bind_node_proxy(node, IRProxyValue(result))
    return

# Global reads that didn't take the fast path above (e.g. HW
# transpose candidates or LLVM-load flagged reads): linearize
# the memref and do a simple 1-D load so the flat index works.
if is_global:
    subs_map = add_emitter_subs(emitter, dynamic_vals_map_start)
    kb_type = MemRefType(kb_src.type)
    phys_strides, _ = kb_type.get_strides_and_offset()
    dyn_sentinel = ShapedType.get_dynamic_stride_or_offset()
    if any(s == dyn_sentinel for s in phys_strides):
        sym_strides = list(
            strides_from_symbolic_shape(
                IndexingContext.current(),
                input_shape,
                allow_mixed_shapes=True,
            )
        )
    else:
        sym_strides = [sympy.Integer(s) for s in phys_strides]
    strides_vals = [gen_sympy_index(subs_map, s) for s in sym_strides]
    zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(sym_strides)
    lin_src, _ = _linearize_memref(
        kb_src, zero_indices, zero_indices, strides_vals
    )
    total_offset = gen_sympy_index(subs_map, flat_offset)
    if mask is None:
        result = vector_d.load(vector_type, lin_src, [total_offset])
    else:
        el_type = vector_type.element_type
        zero = arith_d.constant(el_type, get_constant_attr(0, el_type))
        passthru = vector_d.broadcast(vector_type, zero)
        result = vector_d.maskedload(
            vector_type, lin_src, [total_offset], mask, passthru
        )
    emitter.bind_node_proxy(node, IRProxyValue(result))
    return

# Shared memory paths.
subs_map = add_emitter_subs(emitter, dynamic_vals_map_start)
flat_idx_val = gen_sympy_index(subs_map, flat_offset)
start_indices = [flat_idx_val]
start_indices_wg = [flat_idx_val]
start_indices_th = [arith_d.constant(IndexType.get(), 0)]

result = _create_vec_read_write(
    emitter,
    input_shape,
    kb_src,
    None,
    vector_type,
    start_indices,
    start_indices_wg,
    start_indices_th,
    elements_per_thread,
    get_custom(memory),
    mask,
    node_index=index,
)
emitter.bind_node_proxy(node, IRProxyValue(result))
```
Contributor

we should probably split this logic into multiple functions
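One possible shape for that split, sketched schematically (all names hypothetical; the real helpers would take the emitter state rather than a dict): the top-level function only classifies the read and dispatches, and each of the three emission paths in the block above becomes its own helper.

```python
def _emit_linear_read(ctx):
    """Dispatch to one of the three emission paths (hypothetical refactor)."""
    if ctx["is_global"] and not ctx["use_llvm_load"]:
        return _emit_global_fast_path(ctx)   # hoisted/linearized vector load
    if ctx["is_global"]:
        return _emit_global_fallback(ctx)    # HW-transpose / LLVM-load reads
    return _emit_shared_path(ctx)            # shared-memory path


def _emit_global_fast_path(ctx):
    return "fast"


def _emit_global_fallback(ctx):
    return "fallback"


def _emit_shared_path(ctx):
    return "shared"


assert _emit_linear_read({"is_global": True, "use_llvm_load": False}) == "fast"
assert _emit_linear_read({"is_global": True, "use_llvm_load": True}) == "fallback"
assert _emit_linear_read({"is_global": False, "use_llvm_load": False}) == "shared"
```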

@@ -0,0 +1,234 @@
# Copyright 2025 The IREE Authors
Contributor

nit: 2026 everywhere

```python
    return True


def _expand_mod(expr: sympy.Expr) -> sympy.Expr:
```
Contributor

I really hope we will get rid of this when we switch to the new simplifier.
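For context, the floor/Mod cancellation these helpers perform rests on the identity `x mod n == x - n * floor(x / n)`; once a `Mod` is expanded this way, matching `floor` terms in the flattened offset cancel. A minimal numeric sketch (the actual pass does this symbolically on sympy expressions; `expand_mod` is a hypothetical name):

```python
def expand_mod(x, n):
    """Rewrite x % n via the identity x % n == x - n * (x // n).
    Python's // and % use floor semantics, matching sympy's floor/Mod."""
    return x - n * (x // n)


for x in range(-16, 17):
    for n in (1, 3, 8):
        # After expansion, (x // n) * n + expand_mod(x, n) collapses to x,
        # which is exactly the cancellation that linearization relies on.
        assert (x // n) * n + expand_mod(x, n) == x
```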

Three fixes to the flatten_read_indices pass:

1. Skip reads with MemoryAccessFlags (VOLATILE, NONTEMPORAL).  The
   LINEAR_INDEX codegen fallback path uses vector.maskedload which
   drops volatile semantics.  This caused incorrect streamk partial
   buffer synchronization (stale spinlock reads).

2. Use physical_layout.shape for stride computation when a
   MemoryLayout is present, matching the strides the emitter
   creates for the memref via reinterpret_cast.

3. Use physical (post-mapping) start expressions as bound keys in
   _convert_bounds, falling back to the original index when the
   physical start contains $dynamic_val symbols that are only
   resolvable through the mapping at codegen time.

Made-with: Cursor

Signed-off-by: William G Hatch <[email protected]>
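For point 2 above: when a layout is contiguous and row-major, the strides implied by a shape are the suffix products of its extents, which is what stride computation from a shape amounts to (a sketch with a hypothetical helper; the real code handles dynamic and non-contiguous layouts too):

```python
def row_major_strides(shape):
    """Suffix product: the stride of dim d is the product of all later extents."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


assert row_major_strides([4, 8]) == [8, 1]
assert row_major_strides([2, 3, 5]) == [15, 5, 1]
```

Computing strides from the physical layout's shape (rather than the logical one) keeps them consistent with the memref the emitter builds via `reinterpret_cast`.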
Replace the water-specific emit_water_dialect flag with a general
linearize_reads option (default True).  Set it to False in:

- water_e2e_test: water-opt does not yet understand $LINEAR_INDEX
- waveasm 256x224x256 dynamic+bufops: linearized reads push VGPR
  count past the 256-register limit; disabling linearization lets
  the test pass instead of needing an xfail

Signed-off-by: William G Hatch <[email protected]>