diff --git a/src/main/scala/Core/CSR.scala b/src/main/scala/Core/CSR.scala index 35861722..8c97e134 100644 --- a/src/main/scala/Core/CSR.scala +++ b/src/main/scala/Core/CSR.scala @@ -344,6 +344,41 @@ class CSR(implicit val conf: FlexpretConfiguration) extends Module { } } + + // Implement counting lock + val countinglock = Module(new CountingLock()).io + countinglock.driveInputDefaults() + countinglock.tid := io.rw.thread + + // increment operation, with 0 being interpreted as a reset + when (write && compare_addr(CSRs.countinglock_inc)) { + countinglock.increment := io.rw.data_in + when (io.rw.data_in === 0.U) { + countinglock.reset := true.B + } + } + + // wait operation + // low bits specify lock_id + // high bits specify what value to wait for + when (write && compare_addr(CSRs.countinglock_wait)) { + countinglock.lock_wait := true.B + countinglock.lock_id := io.rw.data_in(conf.threadBits - 1, 0) + countinglock.lock_until := io.rw.data_in >> conf.threadBits + } + + // TODO: does this play nicely with other sleep/wake operations? + // i.e. delay_until instruction, sleeper module + when (countinglock.sleep) { + sleep := true.B + } + for (tid <- 0 until conf.threads) { + when (countinglock.wake(tid)) { + wake(tid) := true.B + } + } + + // exception handling if (conf.exceptions) { when(io.exception) { diff --git a/src/main/scala/Core/countinglock.scala b/src/main/scala/Core/countinglock.scala new file mode 100644 index 00000000..e5ce5fa0 --- /dev/null +++ b/src/main/scala/Core/countinglock.scala @@ -0,0 +1,70 @@ +package flexpret.core +import chisel3._ +import chisel3.util._ + +// A simple implementation of a set of counting locks; one counting lock per thread + +class CountingLockIO(implicit val conf: FlexpretConfiguration) extends Bundle { + // current thread id, used for all 3 operations + val tid = Input(UInt(conf.threadBits.W)) + + // input for increment operation + val increment = Input(UInt(32.W)) + + // input for reset operation + val reset = Input(Bool()) + + // inputs for lock_until operation + val lock_wait = Input(Bool()) + val lock_id = Input(UInt(conf.threadBits.W)) + val lock_until = Input(UInt(32.W)) + + // outputs: sleep if the current thread should sleep, + // wake for whether any thread should wake + val sleep = Output(Bool()) + val wake = Output(Vec(conf.threads, Bool())) + + + def driveInputDefaults() = { + tid := 0.U + increment := 0.U + reset := false.B + lock_wait := false.B + lock_id := 0.U + lock_until := 0.U + } +} + +class CountingLock(implicit val conf: FlexpretConfiguration) extends Module { + val io = IO(new CountingLockIO()) + + // value of counting lock owned by each thread + val regValue = RegInit(VecInit(Seq.fill(conf.threads) { 0.U(32.W) })) + + // what value each thread is waiting for + val regUntil = RegInit(VecInit(Seq.fill(conf.threads) { 0.U(32.W) })) + + // which counting lock each thread is waiting on. own id means not waiting on anything + val regWaitingOn = RegInit(VecInit( Seq.tabulate(conf.threads)(n => n.U(conf.threadBits.W)) )) + + when(io.reset) { + regValue(io.tid) := 0.U + } .otherwise { + regValue(io.tid) := regValue(io.tid) + io.increment + } + + when (io.lock_wait) { + regUntil(io.tid) := io.lock_until + regWaitingOn(io.tid) := io.lock_id + } + + io.sleep := io.lock_wait && io.lock_until > regValue(io.tid) + for (tid <- 0 until conf.threads) { + when ((tid.U =/= io.tid) && regWaitingOn(tid) === io.tid && regUntil(tid) <= regValue(io.tid)) { + io.wake(tid) := true.B + regWaitingOn(tid) := tid.U // set to not waiting on anything + } .otherwise { + io.wake(tid) := false.B + } + } +} \ No newline at end of file diff --git a/src/main/scala/Core/instructions.scala b/src/main/scala/Core/instructions.scala index 557c5987..17e4d18c 100644 --- a/src/main/scala/Core/instructions.scala +++ b/src/main/scala/Core/instructions.scala @@ -261,6 +261,8 @@ object CSRs { val tohost = 0x51e val fromhost = 0x51f val hwlock = 0x520 + val countinglock_inc = 0x540 + val countinglock_wait = 0x541 val cycle = 0xc00 val time = 0xc01 val instret = 0xc02