Skip to content

Commit c2d24d5

Browse files
authored
refactor(difftest): pipeline gateway stages with DecoupleIO (#802)
Convert the difftest gateway path to a Decoupled pipeline and add a reusable PipelineConnect stage register for staged transport between preprocess, validate, replay, squash, delta and batch. For clkgate/backpressure, update Delayer to advance only on real pipeline fire, and rework the FPGA difftest path to propagate host backpressure through DecoupledIO instead of a separate enable signal. Note we keep clockEnable for one extra cycle beyond ready, ensuring the valid signal is sampled correctly on the fire handshake. For delta, replace the old queue-based flow with staged splitters. And non-delta and delta bundles are handled separately: - non-delta bundles are only made valid for the single pipelined.fire cycle, even if the combined output remains valid. - delta bundles continue to drain through the splitters under backpressure. When the final visible delta beat is stalled, block new input for the following cycle so the current beat cannot be replaced before it is consumed. This change preserves correctness when delta emits long consecutive data and decouples stage logic for future optimization.
1 parent 64a81d5 commit c2d24d5

File tree

11 files changed

+354
-221
lines changed

11 files changed

+354
-221
lines changed

src/main/scala/Batch.scala

Lines changed: 75 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import chisel3.util._
2020
import difftest._
2121
import difftest.gateway.GatewayConfig
2222
import difftest.common.DifftestPerf
23-
import difftest.util.LookupTree
23+
import difftest.util.{LookupTree, PipelineConnect}
2424

2525
import scala.collection.mutable.ListBuffer
2626

@@ -62,7 +62,6 @@ class BatchStats(param: BatchParam) extends Bundle {
6262

6363
class BatchOutput(param: BatchParam, config: GatewayConfig) extends Bundle {
6464
val io = new BatchIO(param)
65-
val enable = Bool()
6665
val step = UInt(config.stepWidth.W)
6766
}
6867

@@ -71,13 +70,21 @@ class BatchInfo extends Bundle {
7170
val num = UInt(8.W)
7271
}
7372

73+
class BatchStepResult(param: BatchParam, config: GatewayConfig) extends Bundle {
74+
val data = UInt(param.StepDataBitLen.W)
75+
val info = UInt(param.StepInfoBitLen.W)
76+
// status of step_data_head split in different loc
77+
val status = Vec(param.StepGroupSize, new BatchStats(param))
78+
val trace_info = Option.when(config.hasReplay)(new DiffTraceInfo(config))
79+
}
80+
7481
object Batch {
7582
private val template = ListBuffer.empty[DifftestBundle]
7683

77-
def apply(bundles: MixedVec[Valid[DifftestBundle]], config: GatewayConfig): BatchOutput = {
78-
template ++= chiselTypeOf(bundles).map(_.bits).distinctBy(_.desiredCppName)
79-
val module = Module(new BatchEndpoint(chiselTypeOf(bundles).toSeq, config))
80-
module.in := bundles
84+
def apply(bundles: DecoupledIO[MixedVec[Valid[DifftestBundle]]], config: GatewayConfig): DecoupledIO[BatchOutput] = {
85+
template ++= chiselTypeOf(bundles.bits).map(_.bits).distinctBy(_.desiredCppName)
86+
val module = Module(new BatchEndpoint(chiselTypeOf(bundles.bits).toSeq, config))
87+
module.in <> bundles
8188
module.out
8289
}
8390

@@ -89,29 +96,18 @@ object Batch {
8996
}
9097

9198
class BatchEndpoint(bundles: Seq[Valid[DifftestBundle]], config: GatewayConfig) extends Module {
92-
val in = IO(Input(MixedVec(bundles)))
93-
val param = BatchParam(config, in.map(_.bits).toSeq)
99+
val in = IO(Flipped(Decoupled(MixedVec(bundles))))
100+
val param = BatchParam(config, in.bits.map(_.bits).toSeq)
94101

95102
// Collect valid bundles of same cycle
96-
val collector = Module(new BatchCollector(bundles, param))
97-
collector.in := RegNext(in)
98-
val step_data = collector.step_data
99-
val step_info = collector.step_info
100-
val step_enable = collector.step_enable
101-
val step_status = collector.step_status
103+
val collector = Module(new BatchCollector(bundles, param, config))
104+
PipelineConnect(in, collector.in, collector.in.fire)
102105

103106
// Assemble collected data from different cycles
104107
val assembler = Module(new BatchAssembler(param, config))
105-
assembler.step_data := step_data
106-
assembler.step_info := step_info
107-
assembler.step_status := step_status
108-
assembler.step_enable := step_enable
109-
if (config.hasReplay) {
110-
val trace_info = in.map(_.bits).filter(_.desiredCppName == "trace_info").head.asInstanceOf[DiffTraceInfo]
111-
assembler.step_trace_info.get := trace_info
112-
}
113-
val out = IO(Output(new BatchOutput(param, config)))
114-
out := assembler.out
108+
assembler.in <> collector.out
109+
val out = IO(chiselTypeOf(assembler.out))
110+
out <> assembler.out
115111
}
116112

117113
// Cluster Data from same group in same cycle
@@ -148,51 +144,60 @@ class BatchCluster(bundleType: DifftestBundle, groupSize: Int, param: BatchParam
148144
}
149145

150146
// Collect Data from different group in same cycle
151-
class BatchCollector(bundles: Seq[Valid[DifftestBundle]], param: BatchParam) extends Module {
152-
val in = IO(Input(MixedVec(bundles)))
153-
val step_data = IO(Output(UInt(param.StepDataBitLen.W)))
154-
val step_info = IO(Output(UInt(param.StepInfoBitLen.W)))
155-
// status of step_data_head split in different loc
156-
val step_status = IO(Output(Vec(param.StepGroupSize, new BatchStats(param))))
157-
val step_enable = IO(Output(Bool()))
147+
class BatchCollector(bundles: Seq[Valid[DifftestBundle]], param: BatchParam, config: GatewayConfig) extends Module {
148+
val in = IO(Flipped(Decoupled(MixedVec(bundles))))
149+
val out = IO(Decoupled(new BatchStepResult(param, config)))
158150

159151
def getGroupDataWidth: Seq[Valid[DifftestBundle]] => Int = { group =>
160152
group.length * group.head.bits.getByteAlignWidth
161153
}
162154

163-
val in_group = in.groupBy(_.bits.desiredCppName).values
155+
val in_group = in.bits.groupBy(_.bits.desiredCppName).values
164156
val in_group_single = in_group.filter(_.size == 1).toSeq
165157
val in_group_multi = in_group.filterNot(_.size == 1).flatMap(_.grouped(8)).toSeq
166158
val sorted = in_group_single.sortBy(getGroupDataWidth).reverse ++ in_group_multi.sortBy(getGroupDataWidth)
167159

168160
// Stage 1: concat bundles with same desiredCppName
169-
val group_info = Wire(Vec(param.StepGroupSize, UInt(param.infoWidth.W)))
170-
val group_status = Wire(Vec(param.StepGroupSize, new BatchStats(param)))
171-
val group_data = Wire(MixedVec(sorted.map(getGroupDataWidth).map(group_w => UInt(group_w.W))))
161+
class GroupBundle extends Bundle {
162+
val data = MixedVec(sorted.map(getGroupDataWidth).map(group_w => UInt(group_w.W)))
163+
val info = Vec(param.StepGroupSize, UInt(param.infoWidth.W))
164+
val status = Vec(param.StepGroupSize, new BatchStats(param))
165+
val trace_info = Option.when(config.hasReplay)(new DiffTraceInfo(config))
166+
}
167+
val grouped = Wire(Decoupled(new GroupBundle))
168+
grouped.valid := in.valid
169+
in.ready := grouped.ready
172170

171+
grouped.bits.trace_info.foreach(
172+
_ := in.bits.map(_.bits).filter(_.desiredCppName == "trace_info").head.asInstanceOf[DiffTraceInfo]
173+
)
173174
sorted.zipWithIndex.foreach { case (v_gens, gid) =>
174175
val cluster = Module(new BatchCluster(chiselTypeOf(v_gens.head.bits), v_gens.length, param))
175176
cluster.in := v_gens
176-
val status_base = if (gid == 0) 0.U.asTypeOf(new BatchStats(param)) else group_status(gid - 1)
177+
val status_base = if (gid == 0) 0.U.asTypeOf(new BatchStats(param)) else grouped.bits.status(gid - 1)
177178
cluster.status_base := status_base
178-
group_data(gid) := cluster.out_data
179-
group_info(gid) := cluster.out_info
180-
group_status(gid) := cluster.status_sum
179+
grouped.bits.data(gid) := cluster.out_data
180+
grouped.bits.info(gid) := cluster.out_info
181+
grouped.bits.status(gid) := cluster.status_sum
181182
}
182183

183184
// Stage 2: delay grouped data, concat different group
184-
val delay_group_data = RegNext(group_data)
185-
val delay_group_info = RegNext(group_info)
186-
val delay_group_status = RegNext(group_status)
185+
val delay_grouped = Wire(Decoupled(new GroupBundle))
186+
PipelineConnect(grouped, delay_grouped, delay_grouped.fire)
187+
delay_grouped.ready := out.ready
188+
189+
val delay_group_data = delay_grouped.bits.data
190+
val delay_group_info = delay_grouped.bits.info
191+
val delay_group_status = delay_grouped.bits.status
187192
val info_num = delay_group_status.last.info_size
188193
val BatchStep = Wire(new BatchInfo)
189194
BatchStep.id := Batch.getTemplate.length.U
190195
BatchStep.num := info_num // unused, only for debugging
191196

192-
step_enable := info_num =/= 0.U
197+
out.valid := delay_grouped.valid && info_num =/= 0.U
193198
// append BatchStep to last step_status
194-
step_status := delay_group_status
195-
step_status.last.info_size := delay_group_status.last.info_size + 1.U
199+
out.bits.status := delay_group_status
200+
out.bits.status.last.info_size := delay_group_status.last.info_size + 1.U
196201

197202
val toCat_data = delay_group_data.take(in_group_single.size).reverse
198203
val toCat_info = delay_group_info.take(in_group_single.size).reverse
@@ -226,7 +231,7 @@ class BatchCollector(bundles: Seq[Valid[DifftestBundle]], param: BatchParam) ext
226231
} else {
227232
0.U
228233
}
229-
step_data := res_single.last | res_multi
234+
out.bits.data := res_single.last | res_multi
230235

231236
// Collect info from tail, collect(i) include last 0~i
232237
val toCollect_info = delay_group_info.reverse
@@ -235,20 +240,17 @@ class BatchCollector(bundles: Seq[Valid[DifftestBundle]], param: BatchParam) ext
235240
val info_base = if (idx == 0) BatchStep.asUInt else info_res(idx - 1)
236241
info_res(idx) := Mux(toCollect_info(idx) =/= 0.U, Cat(info_base, toCollect_info(idx)), info_base)
237242
}
238-
step_info := info_res.last
243+
out.bits.info := info_res.last
244+
out.bits.trace_info.foreach(_ := delay_grouped.bits.trace_info.get)
239245
}
240246

241247
// Assemble step_data from different cycles
242248
class BatchAssembler(
243249
param: BatchParam,
244250
config: GatewayConfig,
245251
) extends Module {
246-
val step_data = IO(Input(UInt(param.StepDataBitLen.W)))
247-
val step_info = IO(Input(UInt(param.StepInfoBitLen.W)))
248-
val step_status = IO(Input(Vec(param.StepGroupSize, new BatchStats(param))))
249-
val step_enable = IO(Input(Bool()))
250-
val step_trace_info = Option.when(config.hasReplay)(IO(Input(new DiffTraceInfo(config))))
251-
val out = IO(Output(new BatchOutput(param, config)))
252+
val in = IO(Flipped(Decoupled(new BatchStepResult(param, config))))
253+
val out = IO(Decoupled(new BatchOutput(param, config)))
252254

253255
val state_data = RegInit(0.U(param.MaxDataBitLen.W))
254256
val state_info = RegInit(0.U(param.MaxInfoBitLen.W))
@@ -261,11 +263,15 @@ class BatchAssembler(
261263
// 1. RegNext signal from BatchCollector to cut of combination logic path
262264
// 1. data/info_exceed_vec: mark whether different length fragments of step data/info exceed available space
263265
// 2. concat/remain_stats: record statistic for data/info to be concatenated to output or remained to state
264-
val delay_step_data = RegNext(step_data)
265-
val delay_step_info = RegNext(step_info)
266-
val delay_step_status = RegNext(step_status)
267-
val delay_step_enable = RegNext(step_enable)
268-
val delay_step_trace_info = Option.when(config.hasReplay)(RegNext(step_trace_info.get))
266+
val delay_step = Wire(Decoupled(new BatchStepResult(param, config)))
267+
PipelineConnect(in, delay_step, delay_step.fire)
268+
val want_tick = Wire(Bool())
269+
delay_step.ready := out.ready
270+
val delay_step_data = delay_step.bits.data
271+
val delay_step_info = delay_step.bits.info
272+
val delay_step_status = delay_step.bits.status
273+
val delay_step_enable = delay_step.valid
274+
val delay_step_trace_info = delay_step.bits.trace_info
269275
val data_bytes_avail = param.MaxDataByteLen.U -& state_status.data_bytes
270276
// Always leave space for BatchFinish, use MaxInfoSize - 1
271277
val info_size_avail = (param.MaxInfoSize - 1).U -& state_status.info_size
@@ -285,7 +291,7 @@ class BatchAssembler(
285291

286292
val step_exceed = delay_step_enable && (state_step_cnt === config.batchSize.U)
287293
val cont_exceed = data_exceed || info_exceed
288-
val state_flush = step_enable && step_status.last.data_bytes >= param.MaxDataByteLen.U // use Stage 1 bytes to flush ahead
294+
val state_flush = in.valid && in.bits.status.last.data_bytes >= param.MaxDataByteLen.U // use Stage 1 bytes to flush ahead
289295

290296
if (config.batchSplit) {
291297
val data_exceed_v = VecInit(delay_step_status.map(_.data_bytes > data_bytes_avail && delay_step_enable))
@@ -372,25 +378,27 @@ class BatchAssembler(
372378
DifftestPerf("BatchExceed_timeout", timeout.asUInt)
373379
if (config.hasReplay) DifftestPerf("BatchExceed_trace", trace_exceed.get.asUInt)
374380
}
375-
val in_replay = Option.when(config.hasReplay)(step_trace_info.get.in_replay)
381+
val in_replay = Option.when(config.hasReplay)(delay_step.bits.trace_info.get.in_replay)
376382

377-
val should_tick = timeout || state_flush || cont_exceed || step_exceed ||
383+
want_tick := timeout || state_flush || cont_exceed || step_exceed ||
378384
trace_exceed.getOrElse(false.B) || in_replay.getOrElse(false.B)
385+
val should_tick = want_tick && out.ready
379386
when(!should_tick) {
380387
timeout_count := timeout_count + 1.U
381388
}.otherwise {
382389
timeout_count := 0.U
383390
}
384391

385-
out.io.data := state_data | (append_data << (state_status.data_bytes << 3).asUInt).asUInt
386-
out.io.info := state_info | (append_info << (state_status.info_size * param.infoWidth.U)).asUInt
387-
out.enable := should_tick
388-
out.step := Mux(out.enable, finish_step, 0.U)
392+
out.bits.io.data := state_data | (append_data << (state_status.data_bytes << 3).asUInt).asUInt
393+
out.bits.io.info := state_info | (append_info << (state_status.info_size * param.infoWidth.U)).asUInt
394+
out.bits.step := Mux(out.valid, finish_step, 0.U)
395+
out.valid := want_tick
389396

390-
val state_update = delay_step_enable || state_flush || timeout
397+
val delay_step_fire = delay_step.fire
398+
val state_update = delay_step_fire || (should_tick && !delay_step_enable)
391399

392400
when(state_update) {
393-
when(delay_step_enable) {
401+
when(delay_step_fire) {
394402
when(should_tick) {
395403
state_step_cnt := next_state_step_cnt
396404
state_data := next_state_data

0 commit comments

Comments
 (0)