Skip to content

Commit

Permalink
works
Browse files Browse the repository at this point in the history
  • Loading branch information
shayanhabibi committed Dec 20, 2021
1 parent 4364a57 commit 304a0cd
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 61 deletions.
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,37 @@ for freeing. It is only 8 bytes large making it incredibly memory efficient.

Using futexes makes this primitive truly faster and more efficient than mutexes for its
use case.

## Principle

A thread will acquire the capability to perform an action (write, read, free)
and then **wait** for that action to be allowed. Once the thread has completed
its action, it then releases the capability which will then allow the following
action to be completed.

**Principally, the wrflock is a state machine.**

## Usage

> **Note: the api has not been finalised and is subject to change**
Example is for write, the same can be done for read and free by changing the
prefix letter.

```nim
let lock = initWRFLock()
if lock.wAcquire(): # all operations return bools; they are discardable if you
# know what you're doing.
lock.wWait()
# do write things here
lock.wRelease()
```

By default, the Wait operations for all 3 actions (write, read, free) are blocking
using a futex. You can pass flags to change this to just have the thread yield
to the scheduler for any of the actions.

```nim
let lock = initWRFLock([wWaitYield, rWaitYield, fWaitYield])
```
87 changes: 80 additions & 7 deletions tests/test.nim
Original file line number Diff line number Diff line change
@@ -1,11 +1,84 @@
import std/osproc
import std/strutils
import std/logging
import std/atomics
import std/os
import std/macros

import balls

import wrflock

let v = initWRFLock()
echo v.facquire()
echo v.wacquire()
echo v.wrelease()
echo v.racquire()
echo v.racquire()
echo v.racquire()
const threadCount = 6

addHandler newConsoleLogger()
setLogFilter:
when defined(danger):
lvlNotice
elif defined(release):
lvlInfo
else:
lvlDebug

var lock {.global.} = initWRFLock([wWaitYield, fWaitYield, rWaitYield])
var counter {.global.}: Atomic[int]

proc writeLock() {.thread.} =
sleep(1000)
doassert lock.wAcquire()
lock.wWait()
discard counter.fetchAdd(1)
doassert lock.wRelease()

proc readLock() {.thread.} =
sleep(200)
doassert lock.rAcquire()
lock.rWait()
doassert counter.load() == 1, "lock allowed read before it was written to"
echo counter.load()
doassert lock.rRelease()

proc freeLock() {.thread.} =
sleep(500)
doassert lock.fAcquire()
lock.fWait()
counter.store(-10000)
doassert lock.fRelease()

# try to delay a reasonable amount of time despite platform

template expectCounter(n: int): untyped =
## convenience
try:
check counter.load == n
except Exception:
checkpoint " counter: ", load counter
checkpoint "expected: ", n
raise

suite "wrflock":
block:
## See if it works

var threads: seq[Thread[void]]
newSeq(threads, threadCount)

counter.store 0

var i: int
for thread in threads.mitems:
if i == 0:
createThread(thread, writeLock)
elif i == threadCount - 1:
createThread(thread, freeLock)
else:
createThread(thread, readLock)
inc i
checkpoint "created $# threads" % [ $threadCount ]

for thread in threads.mitems:
joinThread thread
checkpoint "joined $# threads" % [ $threadCount ]


expectCounter -10000
87 changes: 37 additions & 50 deletions wrflock.nim
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import std/times
import wtbanland/futex

import wrflock/spec
export wWaitBlock, wWaitYield, rWaitBlock, rWaitYield, fWaitBlock, fWaitYield

type
WRFLockObj* = object
Expand All @@ -24,24 +25,24 @@ proc `[]`(lock: WRFLock, idx: int): var uint32 {.inline.} =
# ============================================================================ #
# Define Constructors and Destructors
# ============================================================================ #
proc initWRFLockObj(lock: var WRFLockObj; waitType: int; pshared: bool) =
proc initWRFLockObj(lock: var WRFLockObj; waitType: openArray[int]; pshared: bool) =
if pshared:
lock.data = privateMask64 or nextStateWriteMask64
lock.data = 0u or nextStateWriteMask64
else:
lock.data = 0u
lock.data = privateMask64 or nextStateWriteMask64

if (waitType and wWaitYield) != 0:
if wWaitYield in waitType:
lock.data = lock.data or wWaitYieldMask64
if (waitType and rWaitYield) != 0:
if rWaitYield in waitType:
lock.data = lock.data or rWaitYieldMask64
if (waitType and fWaitYield) != 0:
if fWaitYield in waitType:
lock.data = lock.data or fWaitYieldMask64

proc initWRFLockObj(waitType: int = 0; pshared: bool = false): WRFLockObj =
proc initWRFLockObj(waitType: openArray[int]; pshared: bool = false): WRFLockObj =
result = WRFLockObj()
initWRFLockObj(result, waitType, pshared)

proc initWRFLock*(waitType: int = 0; pshared: bool = false): WRFLock =
proc initWRFLock*(waitType: openArray[int] = []; pshared: bool = false): WRFLock =
result = createShared(WRFLockObj)
result[] = initWRFLockObj(waitType, pshared)

Expand All @@ -55,18 +56,16 @@ proc wAcquire*(lock: WRFLock): bool {.discardable.} =
var newData, data: uint32
data = lock.loadState

template lop: untyped =
while true:
if (data and wrAcquireValueMask32) != 0u:
return false # Overflow error
newData = data or wrAcquireValueMask32
if (newData and frAcquireValueMask32) != 0u:
newData = newData or rdNxtLoopFlagMask32
if (newData and nextStateWriteMask32) != 0u:
newData = newData xor (nextStateWriteMask32 or currStateWriteMask32)

lop()
while not lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELAXED, ATOMIC_RELAXED):
lop()
if lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELAXED, ATOMIC_RELAXED):
break

result = true

Expand All @@ -81,43 +80,37 @@ proc rAcquire*(lock: WRFLock): bool {.discardable.} =
wait(lock[stateOffset].addr, data)
data = lock.loadState

if (data and rdAcquireCounterMask32) == rdAcquireCounterMask32:
return false # Overflow error
newData = data + (1 shl rdAcquireCounterShift32)

while not lock[countersOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELAXED, ATOMIC_RELAXED):
while true:
if (data and rdAcquireCounterMask32) == rdAcquireCounterMask32:
return false # Overflow error
newData = data + (1 shl rdAcquireCounterShift32)

if lock[countersOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELAXED, ATOMIC_RELAXED):
break

data = lock.loadState

newData = data or rdAcquireValueMask32
if (newData and nextStateReadFreeMask32) != 0u:
newData = newData xor (nextStateReadFreeMask32 or currStateReadMask32)
while not lock[countersOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELAXED, ATOMIC_RELAXED):
while true:
newData = data or rdAcquireValueMask32
if (newData and nextStateReadFreeMask32) != 0u:
newData = newData xor (nextStateReadFreeMask32 or currStateReadMask32)

if lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELAXED, ATOMIC_RELAXED):
break

result = true

proc fAcquire*(lock: WRFLock): bool {.discardable.} =
var newData, data: uint32
data = lock.loadState

if (data and frAcquireValueMask32) != 0u:
return false # Overflow error
newData = data or frAcquireValueMask32
if (newData and nextStateReadFreeMask32) != 0u:
newData = newData xor (nextStateReadFreeMask32 or currStateFreeMask32)
while not lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELAXED, ATOMIC_RELAXED):
while true:
if (data and frAcquireValueMask32) != 0u:
return false # Overflow error
newData = data or frAcquireValueMask32
if (newData and nextStateReadFreeMask32) != 0u:
newData = newData xor (nextStateReadFreeMask32 or currStateFreeMask32)

if lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELAXED, ATOMIC_RELAXED):
break

result = true

# ============================================================================ #
Expand All @@ -127,7 +120,7 @@ proc wRelease*(lock: WRFLock): bool {.discardable.} =
var newData, data: uint32
data = lock.loadState

template lop: untyped =
while true:
if (data and wrAcquireValueMask32) == 0u:
return false # Overflow error
newData = data and not(wrAcquireValueMask32 or currStateWriteMask32 or rdNxtLoopFlagMask32)
Expand All @@ -136,11 +129,9 @@ proc wRelease*(lock: WRFLock): bool {.discardable.} =
elif (newData and frAcquireValueMask32) != 0u:
newData = newData or currStateFreeMask32
else:
newData = nextStateReadFreeMask32

lop()
while not lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELEASE, ATOMIC_RELAXED):
lop()
newData = newData or nextStateReadFreeMask32
if lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELEASE, ATOMIC_RELAXED):
break

if (
(
Expand All @@ -160,7 +151,7 @@ proc rRelease*(lock: WRFLock): bool {.discardable.} =
var newData, data: uint
data = lock.data.addr.atomicLoadN(ATOMIC_RELAXED)

template lop: untyped =
while true:
if (data and rdAcquireCounterMask64) == 0u:
return false # Overflow error
newData = data - (1 shl rdAcquireCounterShift64)
Expand All @@ -170,11 +161,9 @@ proc rRelease*(lock: WRFLock): bool {.discardable.} =
newData = newData xor (currStateReadMask64 or currStateFreeMask64)
else:
newData = newData xor (currStateReadMask64 or nextStateReadFreeMask64)

lop()
while not lock.data.addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELEASE, ATOMIC_RELAXED):
lop()

if lock.data.addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELEASE, ATOMIC_RELAXED):
break

if (
((newData and fWaitYieldMask64) == 0u) and
((newData and currStateFreeMask64) != 0u)
Expand All @@ -187,18 +176,16 @@ proc fRelease*(lock: WRFLock): bool {.discardable.} =
var newData, data: uint32
data = lock.loadState

template lop: untyped =
if (data and frAcquireValueMask32) != 0u:
while true:
if (data and frAcquireValueMask32) == 0u:
return false # Overflow error
newData = data and not(frAcquireValueMask32 or currStateFreeMask32)
if (newData and wrAcquireValueMask32) != 0u:
newData = newData or currStateWriteMask32
else:
newData = newData or nextStateWriteMask32

lop()
while not lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELEASE, ATOMIC_RELAXED):
lop()
if lock[stateOffset].addr.atomicCompareExchange(data.addr, newdata.addr, true, ATOMIC_RELEASE, ATOMIC_RELAXED):
break

if (
((newData and wWaitYieldMask32) == 0u) and
Expand All @@ -225,7 +212,7 @@ proc wTimeWait*(lock: WRFLock, time: int = 0): bool {.discardable.} =
result = true
break
if (data and wWaitYieldMask32) == 0u:
wait(lock[stateOffset].addr, data)
wait(lock[stateOffset].addr, data) # TODO implement time in wait futexes
else:
if time > 0:
if getTime() > (stime + dur):
Expand Down
4 changes: 2 additions & 2 deletions wrflock.nimble
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = "0.0.0"
version = "0.1.0"
author = "Shayan Habibi"
description = ""
description = "Write, Read, Free lock primitive"
license = "MIT"
3 changes: 1 addition & 2 deletions wrflock/spec.nim
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ const
fWaitBlock* = 16
fWaitYield* = 32


when cpuEndian == littleEndian:
const
countersOffset* = 0
Expand Down Expand Up @@ -81,7 +80,7 @@ const

nextStateWriteMask64*: uint = makeFlags(nextStateWriteMask32, true)
nextStateReadFreeMask64*: uint = makeFlags(nextStateReadFreeMask32, true)
nextStateValueMask64*: uint = nextStateWriteMask64 or nextStateReadFreeMask64
nextStateValueMask64*: uint = nextStateWriteMask64 or nextStateReadFreeMask64 # or nextStateFreeMask64

currStateWriteMask32*: uint32 = 0x00000001
currStateReadMask32*: uint32 = 0x00000002
Expand Down

0 comments on commit 304a0cd

Please sign in to comment.