Skip to content

Commit

Permalink
improve StringPool API, safety and documentation
Browse files Browse the repository at this point in the history
- more doc comments
- compute string hash before acquiring lock
- add safety check for calling deinit while holding lock
- add stringToSliceLock
- add stringToSlice
- remove StringPool.String.empty
- remove hashString
  • Loading branch information
Techatrix committed Dec 18, 2023
1 parent 4eb4239 commit 1b9dd50
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 35 deletions.
6 changes: 3 additions & 3 deletions src/analyser/InternPool.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3755,9 +3755,9 @@ fn formatId(
) @TypeOf(writer).Error!void {
_ = options;
if (fmt.len != 0) std.fmt.invalidFmtError(fmt, ctx.string);
ctx.ip.string_pool.mutex.lock();
defer ctx.ip.string_pool.mutex.unlock();
try writer.print("{}", .{std.zig.fmtId(ctx.ip.string_pool.stringToSliceUnsafe(ctx.string))});
const locked_string = ctx.ip.string_pool.stringToSliceLock(ctx.string);
defer locked_string.release(&ctx.ip.string_pool);
try std.fmt.format(writer, "{}", .{std.zig.fmtId(locked_string.slice)});
}

pub fn fmtId(ip: *InternPool, string: String) std.fmt.Formatter(formatId) {
Expand Down
133 changes: 101 additions & 32 deletions src/analyser/string_pool.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ pub const Config = struct {
MutexType: ?type = null,
};

/// The StringPool is a Data structure that stores only one copy of distinct and immutable strings i.e. `[]const u8`.
///
/// The `getOrPutString` function will intern a given string and return a unique identifier
/// that can then be used to retrieve the original string with the `stringToSlice*` functions.
pub fn StringPool(comptime config: Config) type {
return struct {
const Pool = @This();

/// A unique number that identifier a interned string.
///
/// Two interned string can be checked for equality simply by checking
/// if this identifier are equal if they both come from the same StringPool.
pub const String = enum(u32) {
empty = 0,
_,

pub fn toOptional(self: String) OptionalString {
Expand All @@ -36,7 +43,6 @@ pub fn StringPool(comptime config: Config) type {
};

pub const OptionalString = enum(u32) {
empty = 0,
none = std.math.maxInt(u32),
_,

Expand All @@ -46,35 +52,55 @@ pub fn StringPool(comptime config: Config) type {
}
};

/// asserts that `str` contains no null bytes
/// Asserts that `str` contains no null bytes.
pub fn getString(pool: *Pool, str: []const u8) ?String {
assert(std.mem.indexOfScalar(u8, str, 0) == null);

// precompute the hash before acquiring the lock
const precomputed_key_hash = std.hash_map.hashString(str);

pool.mutex.lock();
defer pool.mutex.unlock();
const index = pool.map.getKeyAdapted(str, std.hash_map.StringIndexAdapter{ .bytes = &pool.bytes }) orelse return null;

const adapter = PrecomputedStringIndexAdapter{
.bytes = &pool.bytes,
.adapted_key = str,
.precomputed_key_hash = precomputed_key_hash,
};

const index = pool.map.getKeyAdapted(str, adapter) orelse return null;
return @enumFromInt(index);
}

/// asserts that `str` contains no null bytes
/// returns `error.OutOfMemory` if adding this new string would increase the amount of allocated bytes above std.math.maxInt(u32)
/// Asserts that `str` contains no null bytes.
/// Returns `error.OutOfMemory` if adding this new string would increase the amount of allocated bytes above std.math.maxInt(u32)
pub fn getOrPutString(pool: *Pool, allocator: Allocator, str: []const u8) error{OutOfMemory}!String {
assert(std.mem.indexOfScalar(u8, str, 0) == null);

const start_index = std.math.cast(u32, pool.bytes.items.len) orelse return error.OutOfMemory;

// precompute the hash before acquiring the lock
const precomputed_key_hash = std.hash_map.hashString(str);

pool.mutex.lock();
defer pool.mutex.unlock();

const adapter = PrecomputedStringIndexAdapter{
.bytes = &pool.bytes,
.adapted_key = str,
.precomputed_key_hash = precomputed_key_hash,
};

pool.bytes.ensureUnusedCapacity(allocator, str.len + 1) catch {
// If allocation fails, try to do the lookup anyway.
const index = pool.map.getKeyAdapted(str, std.hash_map.StringIndexAdapter{ .bytes = &pool.bytes }) orelse return error.OutOfMemory;
const index = pool.map.getKeyAdapted(str, adapter) orelse return error.OutOfMemory;
return @enumFromInt(index);
};

const gop = try pool.map.getOrPutContextAdapted(
allocator,
str,
std.hash_map.StringIndexAdapter{ .bytes = &pool.bytes },
adapter,
std.hash_map.StringIndexContext{ .bytes = &pool.bytes },
);

Expand All @@ -86,13 +112,7 @@ pub fn StringPool(comptime config: Config) type {
return @enumFromInt(gop.key_ptr.*);
}

pub fn hashString(pool: *Pool, hasher: anytype, index: String) void {
pool.mutex.lock();
defer pool.mutex.unlock();
const str = pool.stringToSliceUnsafe(index);
hasher.update(str);
}

/// Caller owns the memory.
pub fn stringToSliceAlloc(pool: *Pool, allocator: Allocator, index: String) Allocator.Error![]const u8 {
pool.mutex.lock();
defer pool.mutex.unlock();
Expand All @@ -101,6 +121,7 @@ pub fn StringPool(comptime config: Config) type {
return try allocator.dupe(u8, std.mem.sliceTo(string_bytes + start, 0));
}

/// Caller owns the memory.
pub fn stringToSliceAllocZ(pool: *Pool, allocator: Allocator, index: String) Allocator.Error![:0]const u8 {
pool.mutex.lock();
defer pool.mutex.unlock();
Expand All @@ -109,10 +130,41 @@ pub fn StringPool(comptime config: Config) type {
return try allocator.dupeZ(u8, std.mem.sliceTo(string_bytes + start, 0));
}

/// storage a slice that points into the internal storage of the `StringPool`.
/// always call `release` method to unlock the `StringPool`.
///
/// see `stringToSliceLock`
pub const LockedString = struct {
slice: [:0]const u8,

pub fn release(locked_string: LockedString, pool: *Pool) void {
_ = locked_string;
pool.mutex.unlock();
}
};

/// returns the underlying slice from an interned string
/// equal strings are guaranteed to share the same storage
///
/// Will lock the `StringPool` until the `release` method is called on the returned locked string.
pub fn stringToSliceLock(pool: *Pool, index: String) LockedString {
pool.mutex.lock();
return .{ .slice = pool.stringToSliceUnsafe(index) };
}

/// returns the underlying slice from an interned string
/// equal strings are guaranteed to share the same storage
///
/// only callable when thread safety is disabled.
pub fn stringToSlice(pool: *Pool, index: String) [:0]const u8 {
if (config.thread_safe) @compileError("use stringToSliceLock instead");
return pool.stringToSliceUnsafe(index);
}

/// returns the underlying slice from an interned string
/// equal strings are guaranteed to share the same storage
pub fn stringToSliceUnsafe(pool: *Pool, index: String) [:0]const u8 {
std.debug.assert(@intFromEnum(index) < pool.bytes.items.len);
assert(@intFromEnum(index) < pool.bytes.items.len);
const string_bytes: [*:0]u8 = @ptrCast(pool.bytes.items.ptr);
const start = @intFromEnum(index);
return std.mem.sliceTo(string_bytes + start, 0);
Expand All @@ -125,6 +177,11 @@ pub fn StringPool(comptime config: Config) type {
pub fn deinit(pool: *Pool, allocator: Allocator) void {
pool.bytes.deinit(allocator);
pool.map.deinit(allocator);
if (builtin.mode == .Debug and !builtin.single_threaded and config.thread_safe) {
// detect deadlock when calling deinit while holding the lock
pool.mutex.lock();
pool.mutex.unlock();
}
pool.* = undefined;
}

Expand All @@ -147,34 +204,55 @@ pub fn StringPool(comptime config: Config) type {

fn print(ctx: FormatContext, comptime fmt_str: []const u8, _: std.fmt.FormatOptions, writer: anytype) @TypeOf(writer).Error!void {
if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, ctx.string);
ctx.pool.mutex.lock();
defer ctx.pool.mutex.unlock();
try writer.writeAll(ctx.pool.stringToSliceUnsafe(ctx.string));
const locked_string = ctx.pool.stringToSliceLock(ctx.string);
defer locked_string.release(ctx.pool);
try writer.writeAll(locked_string.slice);
}
};
}

/// same as `std.hash_map.StringIndexAdapter` but the hash of the adapted key is precomputed
const PrecomputedStringIndexAdapter = struct {
bytes: *const std.ArrayListUnmanaged(u8),
adapted_key: []const u8,
precomputed_key_hash: u64,

pub fn eql(self: @This(), a_slice: []const u8, b: u32) bool {
const b_slice = std.mem.sliceTo(@as([*:0]const u8, @ptrCast(self.bytes.items.ptr)) + b, 0);
return std.mem.eql(u8, a_slice, b_slice);
}

pub fn hash(self: @This(), adapted_key: []const u8) u64 {
assert(std.mem.eql(u8, self.adapted_key, adapted_key));
return self.precomputed_key_hash;
}
};

test StringPool {
const gpa = std.testing.allocator;
var pool = StringPool(.{}){};
var pool = StringPool(.{ .thread_safe = false }){};
defer pool.deinit(gpa);

const str = "All Your Codebase Are Belong To Us";
const index = try pool.getOrPutString(gpa, str);
try std.testing.expectEqualStrings(str, pool.stringToSliceUnsafe(index));

const locked_string = pool.stringToSliceLock(index);
defer locked_string.release(&pool);

try std.testing.expectEqualStrings(str, locked_string.slice);
try std.testing.expectFmt(str, "{}", .{index.fmt(&pool)});
}

test "StringPool - check interning" {
const gpa = std.testing.allocator;
var pool = StringPool(.{}){};
var pool = StringPool(.{ .thread_safe = false }){};
defer pool.deinit(gpa);

const str = "All Your Codebase Are Belong To Us";
const index1 = try pool.getOrPutString(gpa, str);
const index2 = try pool.getOrPutString(gpa, str);
const index3 = pool.getString(str).?;
const storage1 = pool.stringToSliceUnsafe(index1);
const storage1 = pool.stringToSlice(index1);
const storage2 = pool.stringToSliceUnsafe(index2);

try std.testing.expectEqual(index1, index2);
Expand All @@ -185,15 +263,6 @@ test "StringPool - check interning" {
try std.testing.expectEqual(storage1.len, storage2.len);
}

test "StringPool - empty string" {
if (true) return error.SkipZigTest; // TODO
const gpa = std.testing.allocator;
var pool = StringPool(.{}){};
defer pool.deinit(gpa);

try std.testing.expectEqualStrings("", pool.stringToSliceUnsafe(.empty));
}

test "StringPool - getOrPut on existing string without allocation" {
const gpa = std.testing.allocator;
var failing_gpa = std.testing.FailingAllocator.init(gpa, .{ .fail_index = 0 });
Expand Down

0 comments on commit 1b9dd50

Please sign in to comment.