Skip to content

Commit

Permalink
adding matadd
Browse files Browse the repository at this point in the history
  • Loading branch information
SeahK committed Jun 3, 2024
2 parents 1a2c0e6 + f9b387b commit 4b8d610
Show file tree
Hide file tree
Showing 14 changed files with 435 additions and 261 deletions.
7 changes: 7 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
[submodule "software/gemmini-rocc-tests"]
path = software/gemmini-rocc-tests
url = https://github.com/SeahK/gemmini-slam-tests
[submodule "software/onnxruntime-riscv"]
path = software/onnxruntime-riscv
url = https://github.com/pranav-prakash/onnxruntime-riscv.git
[submodule "software/libgemmini"]
path = software/libgemmini
url = https://github.com/ucb-bar/libgemmini.git
[submodule "software/libdma"]
path = software/libdma
url = https://github.com/ucb-bar/libdma
2 changes: 1 addition & 1 deletion software/gemmini-rocc-tests
Submodule gemmini-rocc-tests updated from 1a1a1c to 191fd0
1 change: 1 addition & 0 deletions software/libdma
Submodule libdma added at 3e0712
2 changes: 1 addition & 1 deletion src/main/scala/gemmini/AccumulatorScale.scala
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ object AccumulatorScale {
val neg_q_iexp = neg(q)
val z_iexp = (neg_q_iexp * qln2_inv).asUInt.do_>>(16).asTypeOf(q) // q is non-positive
val z_iexp_saturated = Wire(z_iexp.cloneType)
z_iexp_saturated := Mux((5 until 16).map(z_iexp.asUInt(_)).reduce(_ | _), 32.S, z_iexp)
z_iexp_saturated := Mux((5 until 16).map(z_iexp.asUInt(_)).reduce(_ | _), 32.S.asTypeOf(z_iexp), z_iexp)
val qp_iexp = q.mac(z_iexp, qln2).withWidthOf(q)
val q_poly_iexp = qc.mac(qp_iexp + qb, qp_iexp + qb).withWidthOf(q)
// we dont want a rounding shift
Expand Down
76 changes: 73 additions & 3 deletions src/main/scala/gemmini/ConfigsFP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ object GemminiFPConfigs {
meshColumns = 4,

ld_queue_length = 8,
st_queue_length = 2,
st_queue_length = 4,
ex_queue_length = 8,

reservation_station_entries_ld = 8,
reservation_station_entries_st = 4,
reservation_station_entries_st = 8,
reservation_station_entries_ex = 16,

sp_banks = 4,
Expand Down Expand Up @@ -86,7 +86,54 @@ object GemminiFPConfigs {
mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
)


// SLAM-oriented FP32 configuration derived from FP32DefaultConfig:
// weight-stationary dataflow, a 64 KB scratchpad and a 32 KB accumulator,
// with loop-conv and training-conv support switched off.
val slamFPConfig = FP32DefaultConfig.copy(
  sp_capacity = CapacityInKilobytes(64),
  acc_capacity = CapacityInKilobytes(32),
  dataflow = Dataflow.WS,

  // Accumulator scaling is a pass-through: the scale operand is ignored and
  // the input value is returned unchanged.
  // Earlier variant, kept for reference:
  //acc_scale_args=Some(defaultFPConfig.acc_scale_args.get.copy(num_scale_units=0, latency=1)),
  acc_scale_args = Some(ScaleArguments(
    (t: Float, _: Float) => t,
    1, Float(8, 24), -1,
    identity = "1.0",
    c_str = "((x))"
  )),

  // Move-in scaling multiplies the value by the scale factor.
  // The second argument was lowered from 4 to 3 by the original author and is
  // still marked "(check)" -- presumably a pipeline-stage count; TODO confirm.
  mvin_scale_args = Some(ScaleArguments(
    (t: Float, u: Float) => t * u,
    3, Float(8, 24), -1,
    identity = "1.0",
    c_str = "((x) * (scale))"
  )),
  // Sign-only alternative (2 -> 1 stage), kept for reference:
  //mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => {Mux(u > 0.U.asTypeOf(Float(8, 24)), t, 0.U.asTypeOf(Float(8,24)) - t)}, 1, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
  mvin_scale_acc_args = None,

  // Accumulator organisation and pipeline latencies.
  acc_singleported = false,
  acc_sub_banks = 1,
  acc_banks = 2,
  mesh_output_delay = 2,
  tile_latency = 1,
  acc_latency = 3,

  // Optional datapaths and features that this configuration turns off or pins.
  ex_read_from_acc = false,
  ex_write_to_spad = false,
  has_training_convs = false,
  hardcode_d_to_garbage_addr = true,
  has_loop_conv = false,
  acc_read_full_width = false,

  max_in_flight_mem_reqs = 16,
  headerFileName = "gemmini_params_fp32.h",
  num_counter = 0,
  clock_gate = false // original author's note: "true // enable this"
)


// Dummy-datatype variant of slamFPConfig: the input, accumulator and
// spatial-array output types are all DummySInt(32), both scale functions
// return `dontCare`, and loop-conv support is turned back on.
val FP32DummyConfig = slamFPConfig.copy(
  inputType = DummySInt(32),
  accType = DummySInt(32),
  spatialArrayOutputType = DummySInt(32),

  // Move-in scaling: the Float scale operand is ignored.
  mvin_scale_args = Some(ScaleArguments(
    (t: DummySInt, _: Float) => t.dontCare,
    1, Float(8, 24), -1,
    identity = "1.0",
    c_str = "((x)*(scale))"
  )),

  mvin_scale_acc_args = None,

  // Accumulator scaling: same shape as the move-in scale above.
  acc_scale_args = Some(ScaleArguments(
    (t: DummySInt, _: Float) => t.dontCare,
    1, Float(8, 24), -1,
    identity = "1.0",
    c_str = "((x)*(scale))"
  )),

  has_loop_conv = true
)

//FP16 Half Precision Configuration
val FP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), spatialArrayOutputType = Float(5, 11), accType = Float(8, 24),
tile_latency = 2,
Expand Down Expand Up @@ -123,6 +170,29 @@ class GemminiFP32DefaultConfig extends Config((site, here, up) => {
)
})

/** Config fragment that appends one Gemmini accelerator to the RoCC
  * accelerators already provided by the configs above it.
  *
  * @param gemminiConfig array configuration used to build the accelerator;
  *                      defaults to [[GemminiFPConfigs.slamFPConfig]]
  */
class SLAMFPGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
  gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiFPConfigs.slamFPConfig
) extends Config((site, here, up) => {
  case BuildRoCC =>
    // Keep whatever accelerators were registered upstream, then add ours last.
    val inherited = up(BuildRoCC)
    inherited :+ { (p: Parameters) =>
      // Re-expose the builder's Parameters implicitly for the Gemmini constructor.
      implicit val q: Parameters = p
      // The `val gemmini` binding is kept on purpose: LazyModule takes its
      // instance name from the enclosing val.
      val gemmini = LazyModule(new Gemmini(gemminiConfig))
      gemmini
    }
})

/** Config fragment that appends one Gemmini accelerator, built from the
  * dummy-datatype configuration, to the RoCC accelerators already provided
  * by the configs above it.
  *
  * @param gemminiConfig array configuration used to build the accelerator;
  *                      defaults to [[GemminiFPConfigs.FP32DummyConfig]]
  */
class GemminiFP32DummyConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
  gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiFPConfigs.FP32DummyConfig
) extends Config((site, here, up) => {
  case BuildRoCC =>
    // Keep whatever accelerators were registered upstream, then add ours last.
    val inherited = up(BuildRoCC)
    inherited :+ { (p: Parameters) =>
      // Re-expose the builder's Parameters implicitly for the Gemmini constructor.
      implicit val q: Parameters = p
      // The `val gemmini` binding is kept on purpose: LazyModule takes its
      // instance name from the enclosing val.
      val gemmini = LazyModule(new Gemmini(gemminiConfig))
      gemmini
    }
})

//===========FP16 Default Config=========
class GemminiFP16DefaultConfig extends Config((site, here, up) => {
Expand Down
41 changes: 23 additions & 18 deletions src/main/scala/gemmini/Controller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
val max_exs = reservation_station_entries_ex
val max_sts = reservation_station_entries_st

/*
val (conv_cmd, loop_conv_unroller_busy) = withClock (gated_clock) { LoopConv(raw_cmd, reservation_station.io.conv_ld_completed, reservation_station.io.conv_st_completed, reservation_station.io.conv_ex_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes,
Expand All @@ -153,19 +154,31 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t),
new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) }
*/

val (matadd_cmd, loop_matadd_unroller_busy) = withClock (gated_clock) { LoopMatadd(conv_cmd, reservation_station.io.matadd_ld_completed, reservation_station.io.matadd_st_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_sts, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes, new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t)) }
val (conv_cmd, loop_conv_unroller_busy) = if (has_loop_conv) withClock (gated_clock) { LoopConv(raw_cmd, reservation_station.io.conv_ld_completed, reservation_station.io.conv_st_completed, reservation_station.io.conv_ex_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes,
new ConfigMvinRs1(mvin_scale_t_bits, block_stride_bits, pixel_repeats_bits), new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new ConfigMvoutRs2(acc_scale_t_bits, 32), new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t),
new ConfigExRs1(acc_scale_t_bits), new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t),
new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
has_training_convs, has_max_pool, has_first_layer_optimizations, has_dw_convs) }
else (raw_cmd, false.B)

val (loop_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(matadd_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
val (matmul_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(if (has_loop_conv) conv_cmd else raw_cmd, reservation_station.io.matmul_ld_completed, reservation_station.io.matmul_st_completed, reservation_station.io.matmul_ex_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes, new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t),
new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t)) }

val (loop_cmd, loop_matadd_unroller_busy) = withClock (gated_clock) { LoopMatadd(matmul_cmd, reservation_station.io.matadd_ld_completed, reservation_station.io.matadd_st_completed,
meshRows*tileRows, coreMaxAddrBits, reservation_station_entries, max_lds, max_sts, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes, new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t)) }

val unrolled_cmd = Queue(loop_cmd)
unrolled_cmd.ready := false.B
counters.io.event_io.connectEventSignal(CounterEvent.LOOP_MATMUL_ACTIVE_CYCLES, loop_matmul_unroller_busy)
Expand Down Expand Up @@ -262,32 +275,24 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
ex_controller.io.acc.write <> spad.module.io.acc.write

// Im2Col unit
/*
val im2col = withClock (gated_clock) { Module(new Im2Col(outer.config)) }
// Wire up Im2col
counters.io.event_io.collect(im2col.io.counter)
// im2col.io.sram_reads <> spad.module.io.srams.read
im2col.io.req <> ex_controller.io.im2col.req
ex_controller.io.im2col.resp <> im2col.io.resp
*/

// Wire arbiter for ExecuteController and Im2Col scratchpad reads
(ex_controller.io.srams.read, im2col.io.sram_reads, spad.module.io.srams.read).zipped.foreach { case (ex_read, im2col_read, spad_read) =>
val req_arb = Module(new Arbiter(new ScratchpadReadReq(n=sp_bank_entries), 2))

(ex_controller.io.srams.read, spad.module.io.srams.read).zipped.foreach { case (ex_read, spad_read) =>
val req_arb = Module(new Arbiter(new ScratchpadReadReq(n=sp_bank_entries), 1))
req_arb.io.in(0) <> ex_read.req
req_arb.io.in(1) <> im2col_read.req

spad_read.req <> req_arb.io.out

// TODO if necessary, change how the responses are handled when fromIm2Col is added to spad read interface

ex_read.resp.valid := spad_read.resp.valid
im2col_read.resp.valid := spad_read.resp.valid

ex_read.resp.bits := spad_read.resp.bits
im2col_read.resp.bits := spad_read.resp.bits

spad_read.resp.ready := ex_read.resp.ready || im2col_read.resp.ready
spad_read.resp.ready := ex_read.resp.ready
}

// Wire up controllers to ROB
Expand Down
16 changes: 8 additions & 8 deletions src/main/scala/gemmini/CounterFile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,17 @@ object CounterEvent {
val B_GARBAGE_CYCLES = 36
val D_GARBAGE_CYCLES = 37

val IM2COL_MEM_CYCLES = 38
val IM2COL_ACTIVE_CYCLES = 39
val IM2COL_TRANSPOSER_WAIT_CYCLE = 40
//val IM2COL_MEM_CYCLES = 38
//val IM2COL_ACTIVE_CYCLES = 39
//val IM2COL_TRANSPOSER_WAIT_CYCLE = 40

val RESERVATION_STATION_FULL_CYCLES = 41
val RESERVATION_STATION_ACTIVE_CYCLES = 42
val RESERVATION_STATION_FULL_CYCLES = 38
val RESERVATION_STATION_ACTIVE_CYCLES = 39

val LOOP_MATMUL_ACTIVE_CYCLES = 43
val TRANSPOSE_PRELOAD_UNROLLER_ACTIVE_CYCLES = 44
val LOOP_MATMUL_ACTIVE_CYCLES = 40
val TRANSPOSE_PRELOAD_UNROLLER_ACTIVE_CYCLES = 41

val n = 45
val n = 42
}

object CounterExternal {
Expand Down
64 changes: 39 additions & 25 deletions src/main/scala/gemmini/ExecuteController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val io = IO(new Bundle {
val cmd = Flipped(Decoupled(new GemminiCmd(reservation_station_entries)))

/*
val im2col = new Bundle {
val req = Decoupled(new Im2ColReadReq(config))
val resp = Flipped(Decoupled(new Im2ColReadResp(config)))
}

*/

val srams = new Bundle {
val read = Vec(sp_banks, new ScratchpadReadIO(sp_bank_entries, sp_width))
val write = Vec(sp_banks, new ScratchpadWriteIO(sp_bank_entries, sp_width, (sp_width / (aligned_to * 8)) max 1))
Expand Down Expand Up @@ -111,7 +114,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
icol := ((ocol - 1.U) * weight_stride + krow)//.asSInt
irow := ((orow - 1.U) * weight_stride + krow)//.asSInt

val im2col_turn = WireInit(0.U(9.W))
//val im2col_turn = WireInit(0.U(9.W))

val in_shift = Reg(UInt(log2Up(accType.getWidth).W))
val acc_scale = Reg(acc_scale_t)
Expand All @@ -133,7 +136,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
"Too many inputs are being fed into the single transposer we have")

//fix by input
val im2col_en = config.hasIm2Col.B && weight_stride =/= 0.U
val im2col_en = false.B //config.hasIm2Col.B && weight_stride =/= 0.U

// SRAM addresses of matmul operands
val a_address_rs1 = rs1s(a_address_place).asTypeOf(local_addr_t)
Expand Down Expand Up @@ -311,7 +314,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val b_row_is_not_all_zeros = b_fire_counter < b_rows
val d_row_is_not_all_zeros = block_size.U - 1.U - d_fire_counter < d_rows //Todo: d_fire_counter_mulpre?

val im2col_wire = io.im2col.req.ready
val im2col_wire = false.B //io.im2col.req.ready

def same_bank(addr1: LocalAddr, addr2: LocalAddr, is_garbage1: Bool, is_garbage2: Bool, start_inputting1: Bool, start_inputting2: Bool, can_be_im2colled: Boolean): Bool = {
val addr1_read_from_acc = addr1.is_acc_addr
Expand Down Expand Up @@ -394,11 +397,16 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
mul_pre_counter_lock := true.B
}

/*
when(!io.im2col.resp.bits.im2col_delay && performing_mul_pre){
mul_pre_counter_sub := Mux(mul_pre_counter_sub > 0.U, mul_pre_counter_sub - 1.U, 0.U)
}.elsewhen(io.im2col.resp.bits.im2col_delay){
mul_pre_counter_sub := 2.U
}.otherwise{mul_pre_counter_sub := 0.U}
*/
when(performing_mul_pre){
mul_pre_counter_sub := Mux(mul_pre_counter_sub > 0.U, mul_pre_counter_sub - 1.U, 0.U)
}.otherwise{mul_pre_counter_sub := 0.U}

// The last line in this (long) Boolean is just to make sure that we don't think we're done as soon as we begin firing
// TODO change when square requirement lifted
Expand All @@ -415,9 +423,13 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
}

val d_fire_counter_mulpre = WireInit(b_fire_counter)
/*
when(performing_mul_pre && !io.im2col.resp.bits.im2col_delay&&im2col_en){
d_fire_counter_mulpre := d_fire_counter - mul_pre_counter_sub
}.otherwise{d_fire_counter_mulpre := d_fire_counter}
*/
d_fire_counter_mulpre := d_fire_counter


// Scratchpad reads
for (i <- 0 until sp_banks) {
Expand Down Expand Up @@ -505,27 +517,28 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
{
val read_a = a_valid && start_inputting_a && !multiply_garbage && im2col_wire&&im2col_en //or just im2col_wire

when (read_a && !io.im2col.req.ready) {
when (read_a && !im2col_wire) {
a_ready := false.B
}

io.im2col.req.valid := read_a
io.im2col.req.bits.addr := a_address_rs1
io.im2col.req.bits.icol := icol
io.im2col.req.bits.irow := irow
io.im2col.req.bits.ocol := ocol
io.im2col.req.bits.stride := weight_stride
io.im2col.req.bits.krow := krow
io.im2col.req.bits.kdim2 := kdim2
io.im2col.req.bits.row_turn := row_turn
io.im2col.req.bits.row_left := row_left
io.im2col.req.bits.channel := channel
io.im2col.req.bits.im2col_cmd := im2col_en
io.im2col.req.bits.start_inputting := start_inputting_a
io.im2col.req.bits.weight_double_bank := weight_double_bank
io.im2col.req.bits.weight_triple_bank := weight_triple_bank

io.im2col.resp.ready := mesh.io.a.ready
/*
io.im2col.req.valid := read_a
io.im2col.req.bits.addr := a_address_rs1
io.im2col.req.bits.icol := icol
io.im2col.req.bits.irow := irow
io.im2col.req.bits.ocol := ocol
io.im2col.req.bits.stride := weight_stride
io.im2col.req.bits.krow := krow
io.im2col.req.bits.kdim2 := kdim2
io.im2col.req.bits.row_turn := row_turn
io.im2col.req.bits.row_left := row_left
io.im2col.req.bits.channel := channel
io.im2col.req.bits.im2col_cmd := im2col_en
io.im2col.req.bits.start_inputting := start_inputting_a
io.im2col.req.bits.weight_double_bank := weight_double_bank
io.im2col.req.bits.weight_triple_bank := weight_triple_bank

io.im2col.resp.ready := mesh.io.a.ready
*/
}

// FSM logic
Expand Down Expand Up @@ -802,11 +815,11 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In

val readData = VecInit(io.srams.read.map(_.resp.bits.data))
val accReadData = if (ex_read_from_acc) VecInit(io.acc.read_resp.map(_.bits.data.asUInt)) else readData
val im2ColData = io.im2col.resp.bits.a_im2col.asUInt
//val im2ColData = io.im2col.resp.bits.a_im2col.asUInt

val readValid = VecInit(io.srams.read.map(bank => ex_read_from_spad.B && bank.resp.valid && !bank.resp.bits.fromDMA))
val accReadValid = VecInit(io.acc.read_resp.map(bank => ex_read_from_acc.B && bank.valid && !bank.bits.fromDMA))
val im2ColValid = io.im2col.resp.valid
val im2ColValid = false.B //io.im2col.resp.valid

mesh_cntl_signals_q.io.deq.ready := (!cntl.a_fire || mesh.io.a.fire || !mesh.io.a.ready) &&
(!cntl.b_fire || mesh.io.b.fire || !mesh.io.b.ready) &&
Expand All @@ -829,7 +842,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
//val neg_shift_sub = block_size.U - cntl.c_rows
preload_zero_counter := wrappingAdd(preload_zero_counter, 1.U, block_size.U, dataA_valid && dataD_valid && cntl.preload_zeros && (cntl.perform_single_preload || cntl.perform_mul_pre))

val dataA_unpadded = Mux(cntl.im2colling, im2ColData, Mux(cntl.a_read_from_acc, accReadData(cntl.a_bank_acc), readData(cntl.a_bank)))
//val dataA_unpadded = Mux(cntl.im2colling, im2ColData, Mux(cntl.a_read_from_acc, accReadData(cntl.a_bank_acc), readData(cntl.a_bank)))
val dataA_unpadded = Mux(cntl.a_read_from_acc, accReadData(cntl.a_bank_acc), readData(cntl.a_bank))
val dataB_unpadded = MuxCase(readData(cntl.b_bank), Seq(cntl.accumulate_zeros -> 0.U, cntl.b_read_from_acc -> accReadData(cntl.b_bank_acc)))
val dataD_unpadded = MuxCase(readData(cntl.d_bank), Seq(cntl.preload_zeros -> 0.U, cntl.d_read_from_acc -> accReadData(cntl.d_bank_acc)))

Expand Down
1 change: 1 addition & 0 deletions src/main/scala/gemmini/GemminiConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
has_dw_convs: Boolean = true,
has_normalizations: Boolean = false,
has_first_layer_optimizations: Boolean = true,
has_loop_conv: Boolean = true,

use_firesim_simulation_counters: Boolean = false,

Expand Down
Loading

0 comments on commit 4b8d610

Please sign in to comment.