The main goal of this is to have memory addresses for reads in a loop simplify to `start + IV * stride`, with each of those values either constant or at least hoistable outside of the loop body.

Adds a pre-codegen pipeline that flattens N-dimensional read addresses into 1-dimensional LINEAR_INDEX accesses:

- `flatten_read_indices`: rewrites mapped reads to use a single flat offset
- `annotate_iv_strides`: extracts constant IV strides for loop-carried reads
- Codegen LINEAR_INDEX paths for both vector loads and GatherToLDS
- Removes the old `_try_iv_split_offset` codegen approach in favor of the new pipeline-based linearization
- Helper functions (`mem_simplify`, `linearize_dims`, `_infer_floor_to_exact`) in `mapping_utils` for symbolic floor/Mod cancellation
- Adjusts bounds for linearized reads

This adds some new lit tests showing that with our mxfp4 shuffle layout we can generate linearized reads with constant stride.

This changes a ton of other lit tests. Disclaimer: I was asked to get this PR up ASAP, there is a lot of churn in the lit tests, and I have not yet validated that they are all correct.

Signed-off-by: William G Hatch <[email protected]>
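As a minimal illustration of the goal (a standalone sympy sketch, not the actual wave pipeline; `IV`, `tid`, and the stride values here are made-up stand-ins), flattening a 2-D read address lets us separate a constant per-iteration stride from a loop-invariant start offset:

```python
# Sketch: flatten a 2-D read address row*row_stride + col into a single
# expression of the form start + IV * stride, so `start` and `stride`
# can be hoisted out of the loop and only the IV multiply-add remains.
import sympy

IV = sympy.Symbol("IV", integer=True)    # loop induction variable
tid = sympy.Symbol("tid", integer=True)  # thread id (loop-invariant)
row_stride = sympy.Integer(64)

# 2-D index: the row depends on the IV, the column only on the thread id.
row = 4 * IV + tid // 16
col = tid % 16

flat = sympy.expand(row * row_stride + col)

# Separate the IV-dependent part from the hoistable start offset.
stride = flat.coeff(IV)     # constant stride per loop iteration
start = flat - stride * IV  # loop-invariant start offset

print(stride)  # 256
print(sympy.simplify(start - (64 * (tid // 16) + tid % 16)))  # 0
```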
Hardcode84
reviewed
Mar 26, 2026
```python
flat_offset = idx_seq.start
iv_stride_val = idx_seq.stride

precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None)
```
Contributor
It's fine for now, but we should get rid of it one day.
Comment on lines +939 to +1096
```python
if precomputed_mask_expr is not None and not buffer_ops_enabled:
    mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr)
    mask_vec_type = VectorType.get(
        [elements_per_thread], IntegerType.get_signless(1)
    )
    if mask.type != mask_vec_type:
        mask = vector_d.broadcast(mask_vec_type, mask)
else:
    mask = _build_mask(emitter, index, elements_per_thread, bounds)

is_global = get_custom(memory).type.address_space != SHARED_ADDRESS_SPACE
use_llvm_load = flags != MemoryAccessFlags.NONE

if (
    is_global
    and not use_llvm_load
    and not read_meets_hw_transpose_requirements(
        get_custom(node), emitter.constraints, emitter.options.target
    )
):
    subs_map = add_emitter_subs(emitter, dynamic_vals_map_start)

    kb_type = MemRefType(kb_src.type)
    phys_strides, _ = kb_type.get_strides_and_offset()
    dyn_sentinel = ShapedType.get_dynamic_stride_or_offset()
    if any(s == dyn_sentinel for s in phys_strides):
        sym_strides = list(
            strides_from_symbolic_shape(
                IndexingContext.current(),
                input_shape,
                allow_mixed_shapes=True,
            )
        )
    else:
        sym_strides = [sympy.Integer(s) for s in phys_strides]

    ip = InsertionPoint.current
    owner = ip.block.owner
    is_in_loop = (
        not isinstance(owner, func_d.FuncOp) and owner.name == "scf.for"
    )
    has_iv = iv_stride_val != 0 and is_in_loop
    hoist_ip = InsertionPoint(owner) if is_in_loop else None

    if hoist_ip is not None:
        with hoist_ip:
            strides_vals = [gen_sympy_index(subs_map, s) for s in sym_strides]
            zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(
                sym_strides
            )
            lin_src, _ = _linearize_memref(
                kb_src, zero_indices, zero_indices, strides_vals
            )
            if buffer_ops_enabled:
                valid_bytes = _compute_valid_bytes(
                    lin_src,
                    element_type,
                    input_shape,
                    emitter,
                )
                lin_src = _cast_buffer_and_encode_stride(
                    lin_src,
                    strides_vals,
                    element_type,
                    valid_bytes,
                )
    else:
        strides_vals = [gen_sympy_index(subs_map, s) for s in sym_strides]
        zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(sym_strides)
        lin_src, _ = _linearize_memref(
            kb_src, zero_indices, zero_indices, strides_vals
        )
        if buffer_ops_enabled:
            valid_bytes = _compute_valid_bytes(
                lin_src,
                element_type,
                input_shape,
                emitter,
            )
            lin_src = _cast_buffer_and_encode_stride(
                lin_src,
                strides_vals,
                element_type,
                valid_bytes,
            )

    total_offset = gen_sympy_index(subs_map, flat_offset)

    if mask is None:
        result = vector_d.load(vector_type, lin_src, [total_offset])
    else:
        el_type = vector_type.element_type
        zero = arith_d.constant(el_type, get_constant_attr(0, el_type))
        passthru = vector_d.broadcast(vector_type, zero)
        result = vector_d.maskedload(
            vector_type, lin_src, [total_offset], mask, passthru
        )
    emitter.bind_node_proxy(node, IRProxyValue(result))
    return

# Global reads that didn't take the fast path above (e.g. HW
# transpose candidates or LLVM-load flagged reads): linearize
# the memref and do a simple 1-D load so the flat index works.
if is_global:
    subs_map = add_emitter_subs(emitter, dynamic_vals_map_start)
    kb_type = MemRefType(kb_src.type)
    phys_strides, _ = kb_type.get_strides_and_offset()
    dyn_sentinel = ShapedType.get_dynamic_stride_or_offset()
    if any(s == dyn_sentinel for s in phys_strides):
        sym_strides = list(
            strides_from_symbolic_shape(
                IndexingContext.current(),
                input_shape,
                allow_mixed_shapes=True,
            )
        )
    else:
        sym_strides = [sympy.Integer(s) for s in phys_strides]
    strides_vals = [gen_sympy_index(subs_map, s) for s in sym_strides]
    zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(sym_strides)
    lin_src, _ = _linearize_memref(
        kb_src, zero_indices, zero_indices, strides_vals
    )
    total_offset = gen_sympy_index(subs_map, flat_offset)
    if mask is None:
        result = vector_d.load(vector_type, lin_src, [total_offset])
    else:
        el_type = vector_type.element_type
        zero = arith_d.constant(el_type, get_constant_attr(0, el_type))
        passthru = vector_d.broadcast(vector_type, zero)
        result = vector_d.maskedload(
            vector_type, lin_src, [total_offset], mask, passthru
        )
    emitter.bind_node_proxy(node, IRProxyValue(result))
    return

# Shared memory paths.
subs_map = add_emitter_subs(emitter, dynamic_vals_map_start)
flat_idx_val = gen_sympy_index(subs_map, flat_offset)
start_indices = [flat_idx_val]
start_indices_wg = [flat_idx_val]
start_indices_th = [arith_d.constant(IndexType.get(), 0)]

result = _create_vec_read_write(
    emitter,
    input_shape,
    kb_src,
    None,
    vector_type,
    start_indices,
    start_indices_wg,
    start_indices_th,
    elements_per_thread,
    get_custom(memory),
    mask,
    node_index=index,
)
emitter.bind_node_proxy(node, IRProxyValue(result))
```
Contributor
We should probably split this logic into multiple functions.
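A self-contained analogue of what such a split could look like (hypothetical helper name, not the actual wave code): the stride-resolution logic is duplicated across the fast-path, fallback, and shared-memory branches above, and could be pulled into one function.

```python
# Prefer static physical strides; fall back to symbolic strides when any
# physical stride is the dynamic sentinel. MLIR encodes dynamic strides
# and offsets with int64 minimum (assumption about the current sentinel).
DYNAMIC = -9223372036854775808

def resolve_strides(phys_strides, symbolic_strides):
    """Return static strides when fully known, else the symbolic ones."""
    if any(s == DYNAMIC for s in phys_strides):
        return list(symbolic_strides)
    return [int(s) for s in phys_strides]

print(resolve_strides([64, 1], ["M", 1]))       # [64, 1]
print(resolve_strides([DYNAMIC, 1], ["M", 1]))  # ['M', 1]
```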
@@ -0,0 +1,234 @@

```python
# Copyright 2025 The IREE Authors

    return True


def _expand_mod(expr: sympy.Expr) -> sympy.Expr:
```
Contributor
I really hope we will get rid of this when we switch to the new simplifier.
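For context, a minimal sketch of what a Mod-expansion helper of this shape might do (hypothetical body; the actual `_expand_mod` lives in the new `mapping_utils` file): rewrite `Mod(a, n)` as `a - n*floor(a/n)` so downstream floor/Mod cancellation rules see a single canonical form.

```python
import sympy

def expand_mod(expr: sympy.Expr) -> sympy.Expr:
    """Rewrite every Mod(a, n) in expr as a - n*floor(a/n)."""
    return expr.replace(
        sympy.Mod,
        lambda a, n: a - n * sympy.floor(a / n),
    )

x = sympy.Symbol("x", integer=True, nonnegative=True)
# Mod(x, 16) + 16*floor(x/16) is just x; the expansion makes the
# cancellation visible to plain Add simplification.
expr = sympy.Mod(x, 16) + 16 * sympy.floor(x / 16)
print(sympy.simplify(expand_mod(expr)))  # x
```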
Three fixes to the `flatten_read_indices` pass:

1. Skip reads with `MemoryAccessFlags` (VOLATILE, NONTEMPORAL). The LINEAR_INDEX codegen fallback path uses `vector.maskedload`, which drops volatile semantics. This caused incorrect streamk partial buffer synchronization (stale spinlock reads).
2. Use `physical_layout.shape` for stride computation when a `MemoryLayout` is present, matching the strides the emitter creates for the memref via `reinterpret_cast`.
3. Use physical (post-mapping) start expressions as bound keys in `_convert_bounds`, falling back to the original index when the physical start contains `$dynamic_val` symbols that are only resolvable through the mapping at codegen time.

Made-with: Cursor
Signed-off-by: William G Hatch <[email protected]>
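To illustrate why fix 2 matters (a generic row-major stride computation, not the wave API): when a physical layout pads the innermost dimension, strides derived from the logical shape disagree with the strides the memref actually carries, so flat offsets computed from them would be wrong.

```python
def row_major_strides(shape: list[int]) -> list[int]:
    """Innermost dim has stride 1; each outer stride is the product of
    all inner extents."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# A logical shape [128, 64] padded to a physical shape [128, 72]:
# the strides must come from the physical shape.
print(row_major_strides([128, 64]))  # [64, 1]
print(row_major_strides([128, 72]))  # [72, 1]
```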
Replace the water-specific `emit_water_dialect` flag with a general `linearize_reads` option (default True). Set it to False in:

- `water_e2e_test`: water-opt does not yet understand `$LINEAR_INDEX`
- waveasm 256x224x256 dynamic+bufops: linearized reads push the VGPR count past the 256-register limit; disabling linearization lets the test pass instead of needing an xfail

Signed-off-by: William G Hatch <[email protected]>