Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions janus/pyc/janus/cube/CUBE_V2_SPEC.md
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,28 @@ Performance metrics:
- At 1 GHz: 4.096 TMAC/s (INT16)
```

### 7.6 Benchmark Results (64×64×64 MATMUL)

Actual cycle counts measured via Verilator simulation:

| PE Array | Tile Size | Tiles (M×K×N) | Uops | Theoretical | Actual | Overhead | Efficiency |
|----------|-----------|---------------|------|-------------|--------|----------|------------|
| 16×16 | 16×16 | 4×4×4 | 64 | 67 | 74 | 7 | 90.54% |
| 8×8 | 8×8 | 8×8×8 | 512 | 515 | 579 | 64 | 88.95% |
| 4×4 | 4×4 | 16×16×16 | 4096 | 4099 | 4163 | 64 | 98.46% |

```
Theoretical cycles = uops + pipeline_depth - 1 (startup overhead measured separately as actual - theoretical)
- 16×16: 64 + 4 - 1 = 67 (actual: 74, +7 overhead)
- 8×8: 512 + 4 - 1 = 515 (actual: 579, +64 overhead)
- 4×4: 4096 + 4 - 1 = 4099 (actual: 4163, +64 overhead)

Efficiency = theoretical / actual
- Larger PE arrays need fewer uops for the same problem, giving higher per-uop throughput
- Because startup/drain overhead is roughly fixed, its relative cost shrinks as uop count grows — the 4×4 run (4096 uops) reaches the highest efficiency
- Measured fixed overhead is ~64 cycles (pipeline startup/drain and FSM transitions) for the 8×8 and 4×4 runs, but only 7 cycles for 16×16 — worth investigating the discrepancy
```

---

## 8. MMIO Interface
Expand Down
10 changes: 10 additions & 0 deletions janus/pyc/janus/cube/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ uop4: [C0]──[C1]──[C2]──[C3]──►ACC
Pipeline: 4-cycle latency, 1 uop/cycle throughput
```

### Benchmark Results (64×64×64 MATMUL)

| PE Array | Uops | Actual Cycles | Efficiency |
|----------|------|---------------|------------|
| 16×16 | 64 | 74 | 90.54% |
| 8×8 | 512 | 579 | 88.95% |
| 4×4 | 4096 | 4163 | 98.46% |

See [CUBE_V2_SPEC.md](CUBE_V2_SPEC.md#76-benchmark-results-646464-matmul) for detailed analysis.

### Cube v2 File Structure

```
Expand Down
2 changes: 1 addition & 1 deletion janus/pyc/janus/cube/cube_v2_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# =============================================================================
# Array Dimensions
# =============================================================================
ARRAY_SIZE = 16 # 16×16 systolic array
ARRAY_SIZE = 8 # 8×8 systolic array

# =============================================================================
# Buffer Sizes
Expand Down
102 changes: 67 additions & 35 deletions janus/pyc/janus/cube/cube_v2_decoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Cube v2 MATMUL Decoder and Uop Generator.

Decomposes MATMUL(M, K, N) instructions into micro-operations (uops) for the systolic array.
Each uop represents a 16×16 tile multiplication.
Each uop represents an ARRAY_SIZE×ARRAY_SIZE tile multiplication.
"""

from __future__ import annotations
Expand Down Expand Up @@ -83,22 +83,26 @@ def build_matmul_decoder(
gen_state = _make_uop_gen_state(m, clk, rst, consts)

# Calculate tile counts on start
# tiles = ceil(dim / 16) = (dim + 15) / 16
# tiles = ceil(dim / ARRAY_SIZE) = (dim + ARRAY_SIZE - 1) / ARRAY_SIZE
# Use bit shift for power-of-2 ARRAY_SIZE
import math
shift_amount = int(math.log2(ARRAY_SIZE))

with m.scope("TILE_CALC"):
tile_size = c(ARRAY_SIZE, width=16)
tile_mask = c(ARRAY_SIZE - 1, width=16)

# M tiles
m_plus = inst_m + tile_mask
m_tiles_calc = m_plus >> 4 # Divide by 16
m_tiles_calc = m_plus >> shift_amount

# K tiles
k_plus = inst_k + tile_mask
k_tiles_calc = k_plus >> 4
k_tiles_calc = k_plus >> shift_amount

# N tiles
n_plus = inst_n + tile_mask
n_tiles_calc = n_plus >> 4
n_tiles_calc = n_plus >> shift_amount

# Latch instruction on start
with m.scope("LATCH"):
Expand All @@ -110,14 +114,8 @@ def build_matmul_decoder(
gen_state.k_tiles.set(k_tiles_calc.trunc(width=TILE_IDX_WIDTH), when=start)
gen_state.n_tiles.set(n_tiles_calc.trunc(width=TILE_IDX_WIDTH), when=start)

# Reset tile indices
gen_state.m_tile.set(c(0, width=TILE_IDX_WIDTH), when=start)
gen_state.k_tile.set(c(0, width=TILE_IDX_WIDTH), when=start)
gen_state.n_tile.set(c(0, width=TILE_IDX_WIDTH), when=start)

# Start generating
gen_state.generating.set(consts.one1, when=start)
gen_state.gen_done.set(consts.zero1, when=start)
# Note: tile indices are set below with explicit priority mux
# Note: generating and gen_done are set below with explicit priority

# Uop generation logic
with m.scope("UOP_GEN"):
Expand Down Expand Up @@ -157,7 +155,7 @@ def build_matmul_decoder(
# Output valid uop
uop_valid = can_generate

# Advance tile indices (iterate: k, n, m order for better locality)
# Compute tile index advancement (iterate: k, n, m order for better locality)
with m.scope("ADVANCE"):
# Next k_tile
k_tile_next = k_tile + c(1, width=TILE_IDX_WIDTH)
Expand All @@ -174,29 +172,63 @@ def build_matmul_decoder(
# All done when m wraps
all_done = k_wrap & n_wrap & m_wrap

# Update indices when generating
# K advances every cycle
# Compute new values for tile indices
new_k = k_wrap.select(c(0, width=TILE_IDX_WIDTH), k_tile_next)
gen_state.k_tile.set(new_k, when=can_generate)

# N advances when K wraps
new_n = (k_wrap & n_wrap).select(c(0, width=TILE_IDX_WIDTH), n_tile_next)
gen_state.n_tile.set(new_n, when=can_generate & k_wrap)

# M advances when N wraps
gen_state.m_tile.set(m_tile_next, when=can_generate & k_wrap & n_wrap)

# Done when all tiles generated
gen_state.generating.set(consts.zero1, when=can_generate & all_done)
gen_state.gen_done.set(consts.one1, when=can_generate & all_done)

# Reset logic
with m.scope("RESET"):
gen_state.generating.set(consts.zero1, when=reset_decoder)
gen_state.gen_done.set(consts.zero1, when=reset_decoder)
gen_state.m_tile.set(c(0, width=TILE_IDX_WIDTH), when=reset_decoder)
gen_state.k_tile.set(c(0, width=TILE_IDX_WIDTH), when=reset_decoder)
gen_state.n_tile.set(c(0, width=TILE_IDX_WIDTH), when=reset_decoder)

# Explicit priority mux for generating and gen_done
# Priority: reset_decoder > (can_generate & all_done) > start > hold
with m.scope("STATE_UPDATE"):
current_generating = gen_state.generating.out()
current_gen_done = gen_state.gen_done.out()

# Default: hold current value
next_generating = current_generating
next_gen_done = current_gen_done

# start sets generating=1, gen_done=0
next_generating = start.select(consts.one1, next_generating)
next_gen_done = start.select(consts.zero1, next_gen_done)

# can_generate & all_done sets generating=0, gen_done=1
finish_cond = can_generate & all_done
next_generating = finish_cond.select(consts.zero1, next_generating)
next_gen_done = finish_cond.select(consts.one1, next_gen_done)

# reset_decoder sets generating=0, gen_done=0 (highest priority)
next_generating = reset_decoder.select(consts.zero1, next_generating)
next_gen_done = reset_decoder.select(consts.zero1, next_gen_done)

# Single set call with explicit next value
gen_state.generating.set(next_generating)
gen_state.gen_done.set(next_gen_done)

# Explicit priority mux for tile indices
# Priority: reset_decoder > start > advance > hold
with m.scope("TILE_UPDATE"):
# K tile
current_k = gen_state.k_tile.out()
next_k = current_k
next_k = can_generate.select(new_k, next_k)
next_k = start.select(c(0, width=TILE_IDX_WIDTH), next_k)
next_k = reset_decoder.select(c(0, width=TILE_IDX_WIDTH), next_k)
gen_state.k_tile.set(next_k)

# N tile
current_n = gen_state.n_tile.out()
next_n_val = current_n
next_n_val = (can_generate & k_wrap).select(new_n, next_n_val)
next_n_val = start.select(c(0, width=TILE_IDX_WIDTH), next_n_val)
next_n_val = reset_decoder.select(c(0, width=TILE_IDX_WIDTH), next_n_val)
gen_state.n_tile.set(next_n_val)

# M tile
current_m = gen_state.m_tile.out()
next_m = current_m
next_m = (can_generate & k_wrap & n_wrap).select(m_tile_next, next_m)
next_m = start.select(c(0, width=TILE_IDX_WIDTH), next_m)
next_m = reset_decoder.select(c(0, width=TILE_IDX_WIDTH), next_m)
gen_state.m_tile.set(next_m)

gen_done = gen_state.gen_done.out()

Expand Down
110 changes: 73 additions & 37 deletions janus/pyc/janus/cube/cube_v2_issue_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,30 +90,26 @@ def build_issue_queue(
queue_full = count.out().eq(c(ISSUE_QUEUE_SIZE, width=QUEUE_IDX_WIDTH + 1))
queue_empty = count.out().eq(c(0, width=QUEUE_IDX_WIDTH + 1))

# Enqueue logic
# Enqueue logic - compute enqueue conditions
with m.scope("ENQUEUE"):
can_enqueue = enqueue_valid & ~queue_full & ~flush

# Compute per-entry enqueue conditions
enqueue_this_list = []
for i in range(ISSUE_QUEUE_SIZE):
tail_match = tail.out().eq(c(i, width=QUEUE_IDX_WIDTH))
enqueue_this = can_enqueue & tail_match
enqueue_this_list.append(enqueue_this)

# Write uop data
# Write uop data (these don't have conflicts)
entries[i].uop.l0a_idx.set(enqueue_l0a_idx, when=enqueue_this)
entries[i].uop.l0b_idx.set(enqueue_l0b_idx, when=enqueue_this)
entries[i].uop.acc_idx.set(enqueue_acc_idx, when=enqueue_this)
entries[i].uop.is_first.set(enqueue_is_first, when=enqueue_this)
entries[i].uop.is_last.set(enqueue_is_last, when=enqueue_this)

# Set valid, clear issued
entries[i].valid.set(consts.one1, when=enqueue_this)
entries[i].issued.set(consts.zero1, when=enqueue_this)

# Update tail pointer
next_tail = (tail.out() + consts.one8.trunc(width=QUEUE_IDX_WIDTH)) & c(
ISSUE_QUEUE_SIZE - 1, width=QUEUE_IDX_WIDTH
)
tail.set(next_tail, when=can_enqueue)
# Note: valid and issued updates moved to ENTRY_STATE section
# Note: tail pointer update moved to FLUSH section with explicit priority mux

# Update ready bits based on buffer status
with m.scope("READY_UPDATE"):
Expand Down Expand Up @@ -182,12 +178,13 @@ def build_issue_queue(

found = found | is_ready

# Mark as issued when acknowledged
# Compute mark_issued conditions (moved to ENTRY_STATE section)
issue_and_ack = issue_valid & issue_ack
mark_issued_list = []
for i in range(ISSUE_QUEUE_SIZE):
idx_match = issue_idx.eq(c(i, width=QUEUE_IDX_WIDTH))
mark_issued = issue_and_ack & idx_match
entries[i].issued.set(consts.one1, when=mark_issued)
mark_issued_list.append(mark_issued)

# Create issue result
issued_uop = Uop(
Expand All @@ -199,15 +196,14 @@ def build_issue_queue(
)
issue_result = IssueResult(issue_valid=issue_valid, uop=issued_uop)

# Retire logic (remove completed entries)
# Retire logic (compute retire conditions)
with m.scope("RETIRE"):
# Retire from head when issued
# Compute can_retire conditions
can_retire_list = []
for i in range(ISSUE_QUEUE_SIZE):
head_match = head.out().eq(c(i, width=QUEUE_IDX_WIDTH))
can_retire = head_match & entries[i].valid.out() & entries[i].issued.out()

# Clear entry
entries[i].valid.set(consts.zero1, when=can_retire)
can_retire_list.append(can_retire)

# Update head pointer when retiring
head_entry_issued = consts.zero1
Expand All @@ -218,32 +214,72 @@ def build_issue_queue(
head_entry_issued,
)

next_head = (head.out() + consts.one8.trunc(width=QUEUE_IDX_WIDTH)) & c(
ISSUE_QUEUE_SIZE - 1, width=QUEUE_IDX_WIDTH
)
head.set(next_head, when=head_entry_issued)
# Note: head pointer update moved to FLUSH section with explicit priority mux

# Entry state updates with explicit priority mux
# This consolidates all valid and issued updates to avoid multiple continuous assignments
with m.scope("ENTRY_STATE"):
for i in range(ISSUE_QUEUE_SIZE):
# Valid: Priority: flush > retire > enqueue > hold
current_valid = entries[i].valid.out()
next_valid = current_valid
next_valid = enqueue_this_list[i].select(consts.one1, next_valid)
next_valid = can_retire_list[i].select(consts.zero1, next_valid)
next_valid = flush.select(consts.zero1, next_valid)
entries[i].valid.set(next_valid)

# Issued: Priority: enqueue (clear) > mark_issued (set) > hold
current_issued = entries[i].issued.out()
next_issued = current_issued
next_issued = mark_issued_list[i].select(consts.one1, next_issued)
next_issued = enqueue_this_list[i].select(consts.zero1, next_issued)
entries[i].issued.set(next_issued)

# Update count
with m.scope("COUNT"):
enqueued = can_enqueue
retired = head_entry_issued

next_count = count.out()
# Increment on enqueue
next_count = enqueued.select(next_count + c(1, width=QUEUE_IDX_WIDTH + 1), next_count)
# Decrement on retire
next_count = retired.select(next_count - c(1, width=QUEUE_IDX_WIDTH + 1), next_count)
# Explicit priority mux for count
# Priority: flush > (enqueue/retire) > hold
current_count = count.out()
next_count = current_count

count.set(next_count, when=enqueued | retired)

# Flush logic
with m.scope("FLUSH"):
for i in range(ISSUE_QUEUE_SIZE):
entries[i].valid.set(consts.zero1, when=flush)

head.set(c(0, width=QUEUE_IDX_WIDTH), when=flush)
tail.set(c(0, width=QUEUE_IDX_WIDTH), when=flush)
count.set(c(0, width=QUEUE_IDX_WIDTH + 1), when=flush)
# Increment on enqueue (lower priority)
next_count = enqueued.select(current_count + c(1, width=QUEUE_IDX_WIDTH + 1), next_count)
# Decrement on retire (same priority level, can happen simultaneously)
next_count = retired.select(next_count - c(1, width=QUEUE_IDX_WIDTH + 1), next_count)
# Flush resets to 0 (highest priority)
next_count = flush.select(c(0, width=QUEUE_IDX_WIDTH + 1), next_count)

# Single set call
count.set(next_count)

# Pointer updates with explicit priority mux
with m.scope("PTRS_UPDATE"):
# Explicit priority mux for head and tail
# Priority: flush > normal update > hold
current_head = head.out()
next_head_val = current_head
next_head_val = head_entry_issued.select(
(current_head + consts.one8.trunc(width=QUEUE_IDX_WIDTH)) & c(
ISSUE_QUEUE_SIZE - 1, width=QUEUE_IDX_WIDTH
),
next_head_val,
)
next_head_val = flush.select(c(0, width=QUEUE_IDX_WIDTH), next_head_val)
head.set(next_head_val)

current_tail = tail.out()
next_tail_val = current_tail
next_tail_val = can_enqueue.select(
(current_tail + consts.one8.trunc(width=QUEUE_IDX_WIDTH)) & c(
ISSUE_QUEUE_SIZE - 1, width=QUEUE_IDX_WIDTH
),
next_tail_val,
)
next_tail_val = flush.select(c(0, width=QUEUE_IDX_WIDTH), next_tail_val)
tail.set(next_tail_val)

entries_used = count.out()

Expand Down
17 changes: 13 additions & 4 deletions janus/pyc/janus/cube/cube_v2_l0_reuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,21 @@ def build_l0_buffer_reuse(
loading_reg = m.out("loading", clk=clk, rst=rst, width=1, init=0, en=consts.one1)
ref_count_reg = m.out("ref_count", clk=clk, rst=rst, width=8, init=0, en=consts.one1)

# Create a valid register that mirrors the instance output
valid_reg = m.out("valid", clk=clk, rst=rst, width=1, init=0, en=consts.one1)
valid_reg.set(entry["valid"], when=consts.one1)
# Use the instance's valid output directly (it's already registered)
# Create a dummy register that just holds the value for the status interface
valid_wire = entry["valid"]

# Create a simple wrapper that exposes the valid signal
# We use a register but set it unconditionally to the instance output
# This avoids the extra cycle of latency
class ValidWrapper:
    """Read-only adapter exposing a wire through a register-style ``.out()`` interface.

    The instance's ``valid`` output is already registered, so mirroring it
    into another register would add one cycle of latency; this wrapper lets
    the status struct read the wire directly while keeping the same access
    pattern as a real register.
    """

    # __slots__: this tiny adapter may be instantiated per L0 entry; drop the
    # per-instance __dict__ since it only ever stores one attribute.
    __slots__ = ("_wire",)

    def __init__(self, wire):
        # The wrapped signal; stored as-is and never mutated.
        self._wire = wire

    def out(self):
        """Return the wrapped wire, mirroring the register ``out()`` contract."""
        return self._wire

status = L0EntryStatus(
valid=valid_reg,
valid=ValidWrapper(valid_wire),
loading=loading_reg,
ref_count=ref_count_reg,
)
Expand Down
Loading
Loading