Skip to content

Commit cdf3348

Browse files
committed
Remove tracing elaboratable and replace with Fragment.origins
1 parent 397e400 commit cdf3348

16 files changed

Lines changed: 163 additions & 232 deletions

naps/cores/debug/fsm_status_reg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
def fsm_status_reg(platform, m, fsm):
1111
if isinstance(platform, SocPlatform):
1212
fsm_state = StatusSignal(name=f"{fsm.state.name}_reg") # TODO: use meaningful shape value here (needs deferring)
13-
def signal_fixup_hook(platform, top_fragment: Fragment, sames):
13+
def signal_fixup_hook(platform, top_fragment: Fragment):
1414
fsm_state.width = fsm.state.width
1515
fsm_state.decoder = fsm.state.decoder
1616
platform.prepare_hooks.insert(0, signal_fixup_hook)

naps/cores/peripherals/csr_bank.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010

1111
class CsrBank(Elaboratable):
12-
def __init__(self):
12+
def __init__(self, names):
1313
self.memorymap = MemoryMap()
14+
self.name = ".".join(names)
1415

1516
def reg(self, name: str, signal: _Csr):
1617
assert isinstance(signal, _Csr)
@@ -67,5 +68,5 @@ def handle_write(m, addr, data, write_done):
6768
write_done(Response.ERR)
6869

6970
m = Module()
70-
m.submodules += Peripheral(handle_read, handle_write, self.memorymap)
71+
m.submodules += Peripheral(handle_read, handle_write, self.memorymap, self.name)
7172
return m

naps/cores/peripherals/csr_bank_zynq_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
class TestAxiSlave(unittest.TestCase):
66
def check_csr_bank(self, num_csr=10, testdata=0x12345678, use_axi_interconnect=False):
77
platform = ZynqSocPlatform(SimPlatform(), use_axi_interconnect)
8-
csr_bank = CsrBank()
8+
csr_bank = CsrBank("test")
99
for i in range(num_csr):
1010
csr_bank.reg("csr#{}".format(i), ControlSignal(32))
1111

@@ -26,7 +26,7 @@ def test_csr_bank_interconnect(self):
2626

2727
def test_simple_test_csr_bank(self):
2828
platform = ZynqSocPlatform(SimPlatform())
29-
csr_bank = CsrBank()
29+
csr_bank = CsrBank("test")
3030
csr_bank.reg("csr", ControlSignal(32))
3131

3232
def testbench():

naps/cores/stream/stream_memory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from amaranth import *
2+
from amaranth.lib.memory import Memory
23
from amaranth.utils import bits_for
34

45
from naps import BasicStream, stream_transformer

naps/soc/devicetree_overlay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def devicetree_overlay(platform, overlay_name, overlay_content, placeholder_subs
1717
if not hasattr(platform, "devicetree_overlays"):
1818
platform.devicetree_overlays = dict()
1919

20-
def overlay_hook(platform, top_fragment: Fragment, sames):
20+
def overlay_hook(platform, top_fragment: Fragment):
2121
assert hasattr(top_fragment, "memorymap")
2222
memorymap = top_fragment.memorymap
2323

naps/soc/hooks.py

Lines changed: 114 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -4,104 +4,27 @@
44
from amaranth.hdl._ast import Assign, Property, Switch, Print, Operator, Slice, Part, Concat, SwitchValue, ClockSignal, ResetSignal, Initial, ValueCastable
55
from amaranth.hdl._ir import RequirePosedge
66

7+
from naps.soc.tracing_elaborate import get_elaboratable, get_module
8+
79
from .csr_types import _Csr, ControlSignal, StatusSignal, EventReg
810
from .memorymap import MemoryMap
911
from .pydriver.driver_items import DriverItem, DriverData
10-
from .tracing_elaborate import ElaboratableSames
1112
from ..util.py_serialize import is_py_serializable
1213

1314

14-
def csr_and_driver_item_hook(platform, top_fragment: Fragment, sames: ElaboratableSames):
15+
def csr_and_driver_item_hook(platform, top_fragment: Fragment):
1516
from naps.cores.peripherals.csr_bank import CsrBank
1617
already_done = []
1718

18-
def inner(fragment):
19-
elaboratable = sames.get_elaboratable(fragment)
20-
if elaboratable:
21-
class_members = [(s, getattr(elaboratable, s)) for s in dir(elaboratable)]
22-
csr_signals = [(name, member) for name, member in class_members if isinstance(member, _Csr)]
23-
24-
def get_statement_csrs(stmt):
25-
csrs = set()
26-
if stmt is None:
27-
pass
28-
# statements
29-
elif isinstance(stmt, Assign):
30-
csrs |= get_statement_csrs(stmt.lhs)
31-
csrs |= get_statement_csrs(stmt.rhs)
32-
elif isinstance(stmt, Property):
33-
csrs |= get_statement_csrs(stmt.message)
34-
csrs |= get_statement_csrs(stmt.test)
35-
elif isinstance(stmt, Switch):
36-
csrs |= get_statement_csrs(stmt.test)
37-
for _patterns, statements, _src_loc in stmt.cases:
38-
for statement in statements:
39-
csrs |= get_statement_csrs(statement)
40-
elif isinstance(stmt, Print):
41-
for chunk in stmt.message._chunks:
42-
if isinstance(chunk, tuple):
43-
value, _format_spec = chunk
44-
csrs |= get_statement_csrs(value)
45-
# Values
46-
elif isinstance(stmt, Operator):
47-
for operand in stmt.operands:
48-
csrs |= get_statement_csrs(operand)
49-
elif isinstance(stmt, Slice):
50-
csrs |= get_statement_csrs(stmt.value)
51-
elif isinstance(stmt, Part):
52-
csrs |= get_statement_csrs(stmt.value)
53-
csrs |= get_statement_csrs(stmt.offset)
54-
elif isinstance(stmt, Concat):
55-
for part in stmt.parts:
56-
csrs |= get_statement_csrs(part)
57-
elif isinstance(stmt, SwitchValue):
58-
csrs |= get_statement_csrs(stmt.test)
59-
for pattern, value in stmt.cases:
60-
csrs |= get_statement_csrs(value)
61-
elif isinstance(stmt, (ClockSignal, ResetSignal, Initial)):
62-
pass
63-
elif isinstance(stmt, _Csr):
64-
csrs.add(stmt)
65-
elif isinstance(stmt, (Signal, Const, ValueCastable)):
66-
pass
67-
else:
68-
raise AssertionError("unknown object {} of type {} in statement", stmt, type(stmt))
69-
return csrs
70-
71-
fragment_signals = set()
72-
for _domain, statements in fragment.statements.items():
73-
for stmt in statements:
74-
fragment_signals |= get_statement_csrs(stmt)
75-
76-
csr_signals += [
77-
(signal.name, signal) for signal in fragment_signals
78-
if isinstance(signal, _Csr)
79-
and signal.name != "$signal"
80-
and not any(signal is cmp_signal for name, cmp_signal in csr_signals)
81-
]
82-
83-
new_csr_signals = [(name, signal) for name, signal in csr_signals if not any(signal is done for done in already_done)]
84-
old_csr_signals = [(name, signal) for name, signal in csr_signals if any(signal is done for done in already_done)]
85-
for name, signal in new_csr_signals:
86-
already_done.append(signal)
87-
88-
mmap = fragment.memorymap = MemoryMap()
89-
90-
if new_csr_signals:
91-
m = Module()
92-
csr_bank = m.submodules.csr_bank = CsrBank()
93-
for name, signal in new_csr_signals:
94-
if isinstance(signal, (ControlSignal, StatusSignal, EventReg)):
95-
csr_bank.reg(name, signal)
96-
signal._MustUse__used = True
97-
98-
mmap.allocate_subrange(csr_bank.memorymap, name=None) # name=None means that the Memorymap will be inlined
99-
platform.to_inject_subfragments.append((m, "ignore"))
19+
def inner(fragment, names):
20+
elaboratables = get_elaboratable(fragment) or ()
10021

101-
for name, signal in old_csr_signals:
102-
mmap.add_alias(name, signal)
22+
class_members = []
23+
driver_items = []
10324

104-
driver_items = [
25+
for elaboratable in elaboratables:
26+
class_members += list(elaboratable.__dict__.items())
27+
driver_items += [
10528
(name, getattr(elaboratable, name))
10629
for name in dir(elaboratable)
10730
if isinstance(getattr(elaboratable, name), DriverItem)
@@ -111,19 +34,103 @@ def get_statement_csrs(stmt):
11134
for name in dir(elaboratable)
11235
if is_py_serializable(getattr(elaboratable, name)) and not name.startswith("_")
11336
]
114-
for name, driver_item in driver_items:
115-
fragment.memorymap.add_driver_item(name, driver_item)
37+
38+
csr_signals = [(name, member) for name, member in class_members if isinstance(member, _Csr)]
39+
40+
def get_statement_csrs(stmt):
41+
csrs = set()
42+
if stmt is None:
43+
pass
44+
# statements
45+
elif isinstance(stmt, Assign):
46+
csrs |= get_statement_csrs(stmt.lhs)
47+
csrs |= get_statement_csrs(stmt.rhs)
48+
elif isinstance(stmt, Property):
49+
csrs |= get_statement_csrs(stmt.message)
50+
csrs |= get_statement_csrs(stmt.test)
51+
elif isinstance(stmt, Switch):
52+
csrs |= get_statement_csrs(stmt.test)
53+
for _patterns, statements, _src_loc in stmt.cases:
54+
for statement in statements:
55+
csrs |= get_statement_csrs(statement)
56+
elif isinstance(stmt, Print):
57+
for chunk in stmt.message._chunks:
58+
if isinstance(chunk, tuple):
59+
value, _format_spec = chunk
60+
csrs |= get_statement_csrs(value)
61+
# Values
62+
elif isinstance(stmt, Operator):
63+
for operand in stmt.operands:
64+
csrs |= get_statement_csrs(operand)
65+
elif isinstance(stmt, Slice):
66+
csrs |= get_statement_csrs(stmt.value)
67+
elif isinstance(stmt, Part):
68+
csrs |= get_statement_csrs(stmt.value)
69+
csrs |= get_statement_csrs(stmt.offset)
70+
elif isinstance(stmt, Concat):
71+
for part in stmt.parts:
72+
csrs |= get_statement_csrs(part)
73+
elif isinstance(stmt, SwitchValue):
74+
csrs |= get_statement_csrs(stmt.test)
75+
for pattern, value in stmt.cases:
76+
csrs |= get_statement_csrs(value)
77+
elif isinstance(stmt, (ClockSignal, ResetSignal, Initial)):
78+
pass
79+
elif isinstance(stmt, _Csr):
80+
csrs.add(stmt)
81+
elif isinstance(stmt, (Signal, Const, ValueCastable)):
82+
pass
83+
else:
84+
raise AssertionError("unknown object {} of type {} in statement", stmt, type(stmt))
85+
return csrs
86+
87+
fragment_signals = set()
88+
for _domain, statements in fragment.statements.items():
89+
for stmt in statements:
90+
fragment_signals |= get_statement_csrs(stmt)
91+
92+
csr_signals += [
93+
(signal.name, signal) for signal in fragment_signals
94+
if isinstance(signal, _Csr)
95+
and signal.name != "$signal"
96+
and not any(signal is cmp_signal for name, cmp_signal in csr_signals)
97+
]
98+
99+
new_csr_signals = [(name, signal) for name, signal in csr_signals if not any(signal is done for done in already_done)]
100+
old_csr_signals = [(name, signal) for name, signal in csr_signals if any(signal is done for done in already_done)]
101+
for name, signal in new_csr_signals:
102+
already_done.append(signal)
103+
104+
mmap = fragment.memorymap = MemoryMap()
105+
106+
if new_csr_signals:
107+
m = Module()
108+
csr_bank = m.submodules.csr_bank = CsrBank(names)
109+
print(f"-> adding csr bank for {'.'.join(names)}")
110+
for name, signal in new_csr_signals:
111+
if isinstance(signal, (ControlSignal, StatusSignal, EventReg)):
112+
csr_bank.reg(name, signal)
113+
signal._MustUse__used = True
114+
115+
mmap.allocate_subrange(csr_bank.memorymap, name=None) # name=None means that the Memorymap will be inlined
116+
platform.to_inject_subfragments.append((m, "ignore"))
117+
118+
for name, signal in old_csr_signals:
119+
mmap.add_alias(name, signal)
120+
121+
for name, driver_item in driver_items:
122+
fragment.memorymap.add_driver_item(name, driver_item)
116123

117124
for subfragment, name, _src_loc in fragment.subfragments:
118125
if isinstance(subfragment, RequirePosedge):
119126
continue
120-
inner(subfragment)
121-
inner(top_fragment)
127+
inner(subfragment, [*names, str(name)])
128+
inner(top_fragment, ["top"])
122129

123130

124-
def address_assignment_hook(platform, top_fragment: Fragment, sames: ElaboratableSames):
125-
def inner(fragment):
126-
module = sames.get_module(fragment)
131+
def address_assignment_hook(platform, top_fragment: Fragment):
132+
def inner(fragment, names):
133+
module = get_module(fragment)
127134
if hasattr(module, "peripheral"): # we have the fragment of a marker module for a peripheral
128135
fragment.memorymap = module.peripheral.memorymap
129136
return
@@ -132,17 +139,18 @@ def inner(fragment):
132139
for sub_fragment, sub_name, _src_loc in fragment.subfragments:
133140
if isinstance(sub_fragment, RequirePosedge) or sub_name == "ignore":
134141
continue
135-
inner(sub_fragment)
142+
inner(sub_fragment, [*names, str(sub_name)])
136143

137144
# add everything to the own memorymap
138145
if not hasattr(fragment, "memorymap"):
139146
fragment.memorymap = MemoryMap()
147+
print(f"-> assigning address for {'.'.join(names)}")
140148
for sub_fragment, sub_name, _src_loc in fragment.subfragments:
141149
if isinstance(sub_fragment, RequirePosedge) or sub_name == "ignore":
142150
continue
143151
assert hasattr(sub_fragment, "memorymap") # this holds because we did depth first recursion
144152
fragment.memorymap.allocate_subrange(sub_fragment.memorymap, sub_name)
145-
inner(top_fragment)
153+
inner(top_fragment, ["top"])
146154

147155
# prepare and finalize the memorymap
148156
top_memorymap: MemoryMap = top_fragment.memorymap
@@ -173,20 +181,25 @@ def get_child_bits(memorymap: MemoryMap):
173181
platform.memorymap = top_memorymap
174182

175183

176-
def peripherals_collect_hook(platform, top_fragment: Fragment, sames: ElaboratableSames):
184+
def peripherals_collect_hook(platform, top_fragment: Fragment):
177185
platform.peripherals = []
178186

179-
def collect_peripherals(platform, fragment: Fragment, sames):
180-
module = sames.get_module(fragment)
187+
def collect_peripherals(platform, fragment: Fragment, names):
188+
module = get_module(fragment)
181189
if module:
182190
if hasattr(module, "peripheral"):
191+
try:
192+
for elab in get_elaboratable(fragment) or ():
193+
print(f"-> collected peripheral {elab.name}")
194+
except:
195+
print(f"-> collected peripheral at {'.'.join(names)}")
183196
platform.peripherals.append(module.peripheral)
184197
for f, name, _src_loc in fragment.subfragments:
185198
if isinstance(f, RequirePosedge):
186199
continue
187-
collect_peripherals(platform, f, sames)
200+
collect_peripherals(platform, f, [*names, str(name)])
188201

189-
collect_peripherals(platform, top_fragment, sames)
202+
collect_peripherals(platform, top_fragment, ["top"])
190203

191204
ranges = [(peripheral.range(), peripheral) for peripheral in platform.peripherals
192205
if not peripheral.memorymap.is_empty and not peripheral.memorymap.was_inlined]

naps/soc/peripheral.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def __init__(
2222
self,
2323
handle_read: HandleRead,
2424
handle_write: HandleWrite,
25-
memorymap: MemoryMap
25+
memorymap: MemoryMap,
26+
name: str | None = None
2627
):
2728
"""
2829
A `Peripheral` is a thing that is memorymaped in the SOC.
@@ -38,6 +39,7 @@ def __init__(
3839
self.handle_read = handle_read
3940
self.handle_write = handle_write
4041
self.memorymap = memorymap
42+
self.name = name
4143

4244
def range(self):
4345
return self.memorymap.absolute_range_of_direct_children.range()

naps/soc/platform/jtag/jtag_soc_platform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(self, platform):
2727
self.jtag_active = Signal()
2828
self.jtag_debug_signals = Signal(32)
2929

30-
def peripherals_connect_hook(platform, top_fragment: Fragment, sames):
30+
def peripherals_connect_hook(platform, top_fragment: Fragment):
3131
from naps import JTAGPeripheralConnector
3232
if platform.peripherals:
3333
aggregator = PeripheralsAggregator()

naps/soc/platform/sim/sim_soc_platform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, platform):
2020
assert isinstance(platform, SimPlatform)
2121
super().__init__(platform)
2222

23-
def peripherals_connect_hook(platform, top_fragment: Fragment, sames):
23+
def peripherals_connect_hook(platform, top_fragment: Fragment):
2424
from naps.cores.axi import AxiEndpoint, AxiLitePeripheralConnector
2525

2626
if platform.peripherals:

naps/soc/platform/zynq/zynq_soc_platform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, platform, use_axi_interconnect=False):
2424
self.ps7 = PS7(here_is_the_only_place_that_instanciates_ps7=True)
2525
self.final_to_inject_subfragments.append((self.ps7, "ps7"))
2626

27-
def peripherals_connect_hook(platform, top_fragment: Fragment, sames):
27+
def peripherals_connect_hook(platform, top_fragment: Fragment):
2828
from naps.cores.axi import AxiEndpoint, AxiLitePeripheralConnector, AxiFullToLiteBridge, AxiInterconnect
2929

3030
if platform.peripherals:

0 commit comments

Comments
 (0)