Commit 3b4b7eb

Add masked SRAM write support
2 parents 4163564 + 036e726 commit 3b4b7eb

6 files changed, +183 -6 lines changed

python/assassyn/ir/memory/base.md

Lines changed: 2 additions & 1 deletion
@@ -33,6 +33,7 @@ Base class for memory modules that provides common functionality for SRAM and DR
 - `re: Value` - Read enable signal (combinational input)
 - `addr: Value` - Address signal (combinational input)
 - `wdata: Value` - Write data signal (combinational input)
+- `wmask: Value | None` - Optional per-bit write mask for partial writes
 - `addr_width: int` - Width of the address in bits (derived as log2(depth))
 - `_payload: Array` - Array holding the memory contents (private, not for direct access, owned by the memory instance)

@@ -48,7 +49,7 @@ Initialize memory base class with validation and setup.
 **Returns:** None

 **Explanation:**
-This constructor validates all input parameters and sets up the memory module infrastructure. It enforces that depth must be a power of 2 to enable efficient address decoding using log2 operations. The constructor creates a `RegArray` instance with the specified width and depth to serve as the memory payload, using the instance name for proper identification in generated code. All signal attributes are initialized to None and will be assigned during the `build()` method of concrete implementations.
+This constructor validates all input parameters and sets up the memory module infrastructure. It enforces that depth must be a power of 2 to enable efficient address decoding using log2 operations. The constructor creates a `RegArray` instance with the specified width and depth to serve as the memory payload, using the instance name for proper identification in generated code. All signal attributes are initialized to None and will be assigned during the `build()` method of concrete implementations. `wmask` defaults to `None`, so existing memories keep the original full-word write behavior unless masked writes are explicitly requested.

 **Address Width Derivation Logic:** The address width is calculated as `addr_width = log2(depth)` because:
 1. **Power-of-2 Constraint**: Depth must be a power of 2 to enable efficient address decoding
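
The address-width rule documented in this hunk can be illustrated with a small standalone sketch. The helper name `derive_addr_width` is hypothetical and is not part of assassyn; it only mirrors the documented behavior (reject depths that are not a power of 2, otherwise return log2(depth)).

```python
def derive_addr_width(depth: int) -> int:
    """Hypothetical helper: return log2(depth), rejecting non-power-of-2 depths."""
    # A positive integer is a power of two iff exactly one bit is set.
    if depth <= 0 or depth & (depth - 1) != 0:
        raise ValueError(f"depth must be a power of 2, got {depth}")
    # For a power of two, bit_length() - 1 equals log2(depth) exactly.
    return depth.bit_length() - 1


# A 512-entry memory needs a 9-bit address, which matches the
# Bits(9) address port used by the masked-write test in this commit.
print(derive_addr_width(512))  # 9
```

Using integer bit operations avoids the rounding concerns of a floating-point `log2` call.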

python/assassyn/ir/memory/base.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@ class MemoryBase(Downstream):
     re: Value  # Read enable signal
     addr: Value  # Address signal
     wdata: Value  # Write data signal
+    wmask: Value | None  # Optional write mask signal

     # Derived Values
     addr_width: int  # Width of the address in bits
@@ -70,3 +71,4 @@ def __init__(self, width: int, depth: int, init_file: str | None):
         self.re = None
         self.addr = None
         self.wdata = None
+        self.wmask = None

python/assassyn/ir/memory/sram.md

Lines changed: 4 additions & 2 deletions
@@ -44,7 +44,7 @@ Initialize SRAM module with read data buffer.
 **Explanation:**
 This constructor calls the parent `MemoryBase.__init__()` method to inherit base memory functionality, then creates an additional `dout` register buffer. Both arrays record the SRAM instance as their owner: `_payload` is created in the base class with `owner=self`, and `dout` uses the same override. Downstream passes rely on `Array.is_payload(SRAM)` to distinguish the payload buffer from auxiliary registers. Using `Bits` type ensures compatibility with array read operations that return raw bit values.

-### `def build(self, we, re, addr, wdata)`
+### `def build(self, we, re, addr, wdata, wmask=None)`

 Build the SRAM module with combinational logic for synchronous memory operations.

@@ -53,6 +53,7 @@ Build the SRAM module with combinational logic for synchronous memory operations
 - `re: Value` - Read enable signal
 - `addr: Value` - Address signal
 - `wdata: Value` - Write data signal
+- `wmask: Value | None` - Optional per-bit write mask. `None` preserves full-word writes.

 **Returns:** None

@@ -61,7 +62,7 @@ This method implements the core SRAM functionality using the `@combinational` de

 1. **Signal Assignment:** All input signals are stored as instance attributes for memory operations
 2. **Mutual Exclusion:** Uses `assume(~(we & re))` from [intrinsic.py](../expr/intrinsic.md) to enforce that read and write operations cannot be enabled simultaneously
-3. **Write Operation:** When `we` is enabled, writes `wdata` to `_payload[addr]` using conditional execution
+3. **Write Operation:** When `we` is enabled, either writes `wdata` directly to `_payload[addr]` or merges it with the previous word using `wmask`
 4. **Read Operation:** When `re` is enabled, reads `_payload[addr]` and stores the result in `dout[0]` for downstream modules to access

 **SRAM Read Data Timing:** The relationship between read enable timing and `dout` buffer update:
@@ -73,6 +74,7 @@ This method implements the core SRAM functionality using the `@combinational` de
 **Technical Details:**
 - Uses `Condition` blocks for conditional execution of read/write operations
 - Enforces mutual exclusion between read and write operations using `assume` intrinsic
+- Supports optional masked writes with `merged = (old_data & ~wmask) | (wdata & wmask)`
 - Provides immediate data access without request/response cycles
 - Read data is buffered in `dout` register for downstream module consumption
 - Follows the combinational downstream module pattern for same-cycle signal processing
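
The merge expression documented above, `(old_data & ~wmask) | (wdata & wmask)`, can be checked on plain Python integers. This is a minimal sketch that assumes a 32-bit word; `merge_masked` is a hypothetical name used only for illustration and does not exist in assassyn. The operands are taken from one step of the masked-write test in this commit.

```python
WIDTH = 32                  # assumed word width, matching the 32-bit test SRAM
FULL = (1 << WIDTH) - 1     # all-ones mask used to keep results within WIDTH bits


def merge_masked(old_data: int, wdata: int, wmask: int) -> int:
    """Bits set in wmask are taken from wdata; cleared bits keep old_data."""
    return ((old_data & ~wmask) | (wdata & wmask)) & FULL


# Non-byte-aligned mask: only the nibbles selected by 0x0F0F00F0 change.
merged = merge_masked(old_data=0x55997788, wdata=0xAA5500CC, wmask=0x0F0F00F0)
print(hex(merged))  # 0x5a9577c8

# A zero mask leaves the stored word untouched, whatever wdata is.
print(hex(merge_masked(0x11223344, 0xFFFFFFFF, 0x00000000)))  # 0x11223344
```

An all-ones mask reduces the expression to `wdata`. Note that the committed code distinguishes this from `wmask=None`, which skips the read-modify-write path entirely.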

python/assassyn/ir/memory/sram.py

Lines changed: 9 additions & 2 deletions
@@ -34,25 +34,32 @@ def __init__(self, width: int, depth: int, init_file: str | None):
         )

     @combinational
-    def build(self, we, re, addr, wdata):  # pylint: disable=too-many-arguments
+    def build(self, we, re, addr, wdata, wmask=None):  # pylint: disable=too-many-arguments
         '''The constructor for the SRAM module.

         Args:
             we: Value: The write enable signal.
             re: Value: The read enable signal.
             addr: Value: The address signal.
             wdata: Value: The write data signal.
+            wmask: Optional per-bit write mask. `None` keeps full-word writes.
         '''
         self.we = we
         self.re = re
         self.addr = addr
         self.wdata = wdata
+        self.wmask = wmask

         # Enforce that we and re cannot be both enabled
         assume(~(we & re))

         with Condition(we):
-            self._payload[addr] = wdata
+            if wmask is None:
+                self._payload[addr] = wdata
+            else:
+                old_data = self._payload[addr]
+                merged = (old_data & (~self.wmask)) | (wdata & self.wmask)
+                self._payload[addr] = merged
         with Condition(re):
             self.dout[0] = self._payload[addr]


python/ci-tests/README.md

Lines changed: 5 additions & 1 deletion
@@ -15,6 +15,7 @@
 | `test_select, test_select1hot` | gramme : `select` |
 | `test_testbench` | Usage of `with Cycle(1):` |
 | `test_explict_pop, test_peek` | gramme in `Port` |
+| `test_sram, test_sram_masked` | SRAM full-word access and masked write |
 | | |
 | `test_fifo1, test_bind, `<br>`test_eager_bind, test_imbalance, `<br>`test_fifo_valid, test_wait_until` | sth about **Pure Sequential Logic** |
 | `test_comb_expose, test_toposort`<br />`test_downstream, ` | sth about **Pure Combinational Logic** |
@@ -63,4 +64,7 @@
 14. `test_explict_pop`
     + An alternative method for reading port data.
 15. `test_peek`
-    + Similar to the operation of viewing the top of a queue in a `Queue`. It corresponds to the `front()` operation in the STL of C++ queues. Essentially, it is looking at the top element of the queue without removing it.
+    + Similar to the operation of viewing the top of a queue in a `Queue`. It corresponds to the `front()` operation in the STL of C++ queues. Essentially, it is looking at the top element of the queue without removing it.
+16. `test_sram, test_sram_masked`
+    + `test_sram` covers the original full-word SRAM read/write behavior.
+    + `test_sram_masked` adds regression coverage for masked writes, including zero mask, each byte lane, both halfword lanes, a non-byte-aligned bit mask, and cross-address isolation.
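
The mask categories listed for `test_sram_masked` (byte lanes, halfword lanes, a non-byte-aligned mask) correspond to simple shifted bit patterns on a 32-bit word. The sketch below shows one way to derive them; `lane_mask` is a hypothetical helper introduced for illustration and is not part of the test suite, which hard-codes the constants.

```python
def lane_mask(lane: int, lane_bits: int) -> int:
    """Hypothetical helper: mask selecting lane number `lane` of width `lane_bits`."""
    return ((1 << lane_bits) - 1) << (lane * lane_bits)


# Four byte lanes and two halfword lanes of a 32-bit word.
byte_masks = [lane_mask(i, 8) for i in range(4)]
half_masks = [lane_mask(i, 16) for i in range(2)]
print([f"0x{m:08X}" for m in byte_masks])
# ['0x000000FF', '0x0000FF00', '0x00FF0000', '0xFF000000']
print([f"0x{m:08X}" for m in half_masks])
# ['0x0000FFFF', '0xFFFF0000']

# The non-byte-aligned mask from the test is the union of nibbles 1, 4 and 6.
nibble_mask = lane_mask(1, 4) | lane_mask(4, 4) | lane_mask(6, 4)
print(f"0x{nibble_mask:08X}")  # 0x0F0F00F0
```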
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+"""CI test for SRAM masked-write support."""
+
+import re
+
+from assassyn.frontend import *
+from assassyn.test import run_test
+
+
+ADDR0 = 0x12
+ADDR1 = 0x34
+
+OPS = [
+    {"kind": "write", "addr": ADDR0, "wdata": 0x11223344, "wmask": 0xFFFFFFFF},
+    {"kind": "read", "addr": ADDR0, "expected": 0x11223344},
+    {"kind": "write", "addr": ADDR0, "wdata": 0xFFFFFFFF, "wmask": 0x00000000},
+    {"kind": "read", "addr": ADDR0, "expected": 0x11223344},
+    {"kind": "write", "addr": ADDR0, "wdata": 0x000000AA, "wmask": 0x000000FF},
+    {"kind": "read", "addr": ADDR0, "expected": 0x112233AA},
+    {"kind": "write", "addr": ADDR0, "wdata": 0x0000BB00, "wmask": 0x0000FF00},
+    {"kind": "read", "addr": ADDR0, "expected": 0x1122BBAA},
+    {"kind": "write", "addr": ADDR0, "wdata": 0x00CC0000, "wmask": 0x00FF0000},
+    {"kind": "read", "addr": ADDR0, "expected": 0x11CCBBAA},
+    {"kind": "write", "addr": ADDR0, "wdata": 0xDD000000, "wmask": 0xFF000000},
+    {"kind": "read", "addr": ADDR0, "expected": 0xDDCCBBAA},
+    {"kind": "write", "addr": ADDR0, "wdata": 0x0000EEFF, "wmask": 0x0000FFFF},
+    {"kind": "read", "addr": ADDR0, "expected": 0xDDCCEEFF},
+    {"kind": "write", "addr": ADDR0, "wdata": 0xA1B20000, "wmask": 0xFFFF0000},
+    {"kind": "read", "addr": ADDR0, "expected": 0xA1B2EEFF},
+    {"kind": "write", "addr": ADDR1, "wdata": 0x55667788, "wmask": 0xFFFFFFFF},
+    {"kind": "read", "addr": ADDR1, "expected": 0x55667788},
+    {"kind": "write", "addr": ADDR1, "wdata": 0x00990000, "wmask": 0x00FF0000},
+    {"kind": "read", "addr": ADDR1, "expected": 0x55997788},
+    {"kind": "write", "addr": ADDR1, "wdata": 0xAA5500CC, "wmask": 0x0F0F00F0},
+    {"kind": "read", "addr": ADDR1, "expected": 0x5A9577C8},
+    {"kind": "read", "addr": ADDR0, "expected": 0xA1B2EEFF},
+]
+
+
+class ReadObserver(Module):
+
+    def __init__(self):
+        super().__init__(
+            ports={
+                "step": Port(Bits(8)),
+                "addr": Port(Bits(9)),
+            }
+        )
+
+    @module.combinational
+    def build(self, rdata: RegArray):
+        step, addr = self.pop_all_ports(True)
+        log(
+            "masked_read step={} addr=0x{:03x} data=0x{:08x}",
+            step,
+            addr,
+            rdata[0].bitcast(Bits(32)),
+        )
+
+
+class Launcher(Module):
+
+    def __init__(self, target):
+        super().__init__(ports={})
+        self.target = target
+
+    @module.combinational
+    def build(self):
+        self.target.async_called()
+
+
+class MaskedDriver(Module):
+
+    def __init__(self, observer):
+        super().__init__(ports={})
+        self.name = "Driver"
+        self.observer = observer
+
+    @module.combinational
+    def build(self):
+        phase_bits = max(1, (len(OPS) + 1).bit_length())
+        phase = RegArray(UInt(phase_bits), 1, initializer=[0])
+        state = phase[0]
+        next_state = state + UInt(phase_bits)(1)
+        (phase & self)[0] <= next_state
+
+        we = Bits(1)(0)
+        re = Bits(1)(0)
+        addr = Bits(9)(0)
+        wdata = Bits(32)(0)
+        wmask = Bits(32)(0)
+
+        for idx, op in enumerate(OPS):
+            is_step = state == UInt(phase_bits)(idx)
+            addr_bits = Bits(9)(op["addr"])
+            addr = is_step.select(addr_bits, addr)
+
+            if op["kind"] == "write":
+                we = is_step.select(Bits(1)(1), we)
+                wdata = is_step.select(Bits(32)(op["wdata"]), wdata)
+                wmask = is_step.select(Bits(32)(op["wmask"]), wmask)
+            else:
+                re = is_step.select(Bits(1)(1), re)
+                with Condition(is_step):
+                    self.observer.async_called(step=Bits(8)(idx), addr=addr_bits)
+
+        sram = SRAM(32, 512, None)
+        sram.build(we, re, addr, wdata, wmask)
+
+        with Condition(state == UInt(phase_bits)(len(OPS))):
+            finish()
+
+        return sram
+
+
+READ_RE = re.compile(
+    r"masked_read step=(\d+) addr=0x([0-9a-fA-F]+) data=0x([0-9a-fA-F]+)"
+)
+
+
+def check(raw):
+    expected_reads = [
+        (idx, op["addr"], op["expected"])
+        for idx, op in enumerate(OPS)
+        if op["kind"] == "read"
+    ]
+
+    actual_reads = []
+    for line in raw.splitlines():
+        if "[readobserver" not in line.lower():
+            continue
+        match = READ_RE.search(line)
+        assert match is not None, f"Unexpected ReadObserver log line: {line}"
+        actual_reads.append(
+            (
+                int(match.group(1)),
+                int(match.group(2), 16),
+                int(match.group(3), 16),
+            )
+        )
+
+    assert actual_reads == expected_reads, (
+        f"Masked SRAM reads mismatch.\n"
+        f"expected={expected_reads}\n"
+        f"actual={actual_reads}"
+    )
+
+
+def test_sram_masked_write():
+    def top():
+        observer = ReadObserver()
+        driver = MaskedDriver(observer)
+        sram = driver.build()
+        launcher = Launcher(driver)
+        launcher.build()
+        observer.build(sram.dout)
+
+    run_test("sram_masked", top, check, sim_threshold=200, idle_threshold=200)
+
+
+if __name__ == "__main__":
+    test_sram_masked_write()
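
The `expected` values in `OPS` can be cross-checked against a software reference model. The sketch below is not part of the commit. It replays a shortened operation list on a dict-backed memory and assumes that unwritten words read as 0, which the real SRAM with `init_file=None` may not guarantee. `replay` and `sample_ops` are hypothetical names.

```python
def replay(ops, width=32):
    """Replay write/read ops on a dict-backed model and return the read values."""
    full = (1 << width) - 1
    memory = {}   # addr -> stored word; missing entries are assumed to read as 0
    reads = []
    for op in ops:
        if op["kind"] == "write":
            old = memory.get(op["addr"], 0)
            mask = op["wmask"] & full
            # Same merge as the SRAM: masked bits from wdata, the rest from old.
            memory[op["addr"]] = (old & ~mask & full) | (op["wdata"] & mask)
        else:
            reads.append(memory.get(op["addr"], 0))
    return reads


ADDR0, ADDR1 = 0x12, 0x34
sample_ops = [
    {"kind": "write", "addr": ADDR0, "wdata": 0x11223344, "wmask": 0xFFFFFFFF},
    {"kind": "write", "addr": ADDR0, "wdata": 0x000000AA, "wmask": 0x000000FF},
    {"kind": "write", "addr": ADDR1, "wdata": 0x55667788, "wmask": 0xFFFFFFFF},
    {"kind": "write", "addr": ADDR1, "wdata": 0x00990000, "wmask": 0x00FF0000},
    {"kind": "write", "addr": ADDR1, "wdata": 0xAA5500CC, "wmask": 0x0F0F00F0},
    {"kind": "read", "addr": ADDR1},
    {"kind": "read", "addr": ADDR0},  # unaffected by the writes to ADDR1
]
print([hex(v) for v in replay(sample_ops)])  # ['0x5a9577c8', '0x112233aa']
```

The final read of `ADDR0` shows the cross-address isolation property: masked writes to `ADDR1` leave the word at `ADDR0` unchanged.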
