Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize ZRANK to avoid path comparisons #1389

Merged
merged 8 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,10 @@ typedef struct zskiplistNode {
struct zskiplistNode *backward;
struct zskiplistLevel {
struct zskiplistNode *forward;
/* At each level we keep the span, which is the number of elements in the "subtree"
* from this node at this level to the next node at the same level.
* The one exception is level 0, where the span can only be 1, or 0 for the last element in the list.
* So at level 0 we use this field to hold the height of the node, i.e. its number of levels. */
unsigned long span;
} level[];
} zskiplistNode;
Expand Down
104 changes: 82 additions & 22 deletions src/t_zset.c
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,51 @@ void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap);
zskiplistNode *zslGetElementByRankFromNode(zskiplistNode *start_node, int start_level, unsigned long rank);
zskiplistNode *zslGetElementByRank(zskiplist *zsl, unsigned long rank);

/* Return the span of node 'x' at 'level'.
 *
 * The span field of level 0 is repurposed to hold the node height, so it is
 * never read here as a span. For level 0 the span is derived from the forward
 * pointer instead: 1 when the node has a successor, 0 when it does not.
 * For every higher level the recorded span is returned as is. */
static inline unsigned long zslGetNodeSpanAtLevel(zskiplistNode *x, int level) {
    if (level <= 0) {
        return x->level[level].forward != NULL ? 1UL : 0UL;
    }
    return x->level[level].span;
}

/* Store 'span' as the span of node 'x' at 'level'.
 *
 * The call is a no-op unless 'level' is greater than 0: the level 0 span field
 * holds the node height and must not be overwritten. */
static inline void zslSetNodeSpanAtLevel(zskiplistNode *x, int level, unsigned long span) {
    if (level <= 0) {
        return;
    }
    x->level[level].span = span;
}

/* Add 'incr' to the span of node 'x' at 'level'.
 *
 * The call is a no-op unless 'level' is greater than 0: the level 0 span field
 * holds the node height and must not be modified. */
static inline void zslIncrNodeSpanAtLevel(zskiplistNode *x, int level, unsigned long incr) {
    if (level <= 0) {
        return;
    }
    x->level[level].span += incr;
}

/* Subtract 'decr' from the span of node 'x' at 'level'.
 *
 * The call is a no-op unless 'level' is greater than 0: the level 0 span field
 * holds the node height and must not be modified. */
static inline void zslDecrNodeSpanAtLevel(zskiplistNode *x, int level, unsigned long decr) {
    if (level <= 0) {
        return;
    }
    x->level[level].span -= decr;
}

/* Return the height of node 'x'.
 *
 * The real span at level 0 is always 1 (or 0 for the last node), so it does
 * not need to be stored; the level 0 span field carries the node height
 * instead, and this is where it is read back. */
static inline unsigned long zslGetNodeHeight(zskiplistNode *x) {
    return x->level->span;
}

/* Record 'height' as the height of node 'x'.
 *
 * The real span at level 0 is always 1 (or 0 for the last node), so it does
 * not need to be stored; the level 0 span field is used to carry the node
 * height instead. */
static inline void zslSetNodeHeight(zskiplistNode *x, int height) {
    x->level->span = (unsigned long)height;
}

/* Create a skiplist node with the specified number of levels.
* The SDS string 'ele' is referenced by the node after the call. */
zskiplistNode *zslCreateNode(int level, double score, sds ele) {
zskiplistNode *zn = zmalloc(sizeof(*zn) + level * sizeof(struct zskiplistLevel));
zskiplistNode *zslCreateNode(int height, double score, sds ele) {
zskiplistNode *zn = zmalloc(sizeof(*zn) + height * sizeof(struct zskiplistLevel));
zn->score = score;
zn->ele = ele;
zslSetNodeHeight(zn, height);
return zn;
}

Expand Down Expand Up @@ -147,7 +186,7 @@ zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) {
while (x->level[i].forward &&
(x->level[i].forward->score < score ||
(x->level[i].forward->score == score && sdscmp(x->level[i].forward->ele, ele) < 0))) {
rank[i] += x->level[i].span;
rank[i] += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
update[i] = x;
Expand All @@ -161,23 +200,24 @@ zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) {
for (i = zsl->level; i < level; i++) {
rank[i] = 0;
update[i] = zsl->header;
update[i]->level[i].span = zsl->length;
zslSetNodeSpanAtLevel(update[i], i, zsl->length);
}
zsl->level = level;
zslSetNodeHeight(zsl->header, level);
}
x = zslCreateNode(level, score, ele);
for (i = 0; i < level; i++) {
x->level[i].forward = update[i]->level[i].forward;
update[i]->level[i].forward = x;

/* update span covered by update[i] as x is inserted here */
x->level[i].span = update[i]->level[i].span - (rank[0] - rank[i]);
update[i]->level[i].span = (rank[0] - rank[i]) + 1;
zslSetNodeSpanAtLevel(x, i, zslGetNodeSpanAtLevel(update[i], i) - (rank[0] - rank[i]));
zslSetNodeSpanAtLevel(update[i], i, (rank[0] - rank[i]) + 1);
}

/* increment span for untouched levels */
for (i = level; i < zsl->level; i++) {
update[i]->level[i].span++;
zslIncrNodeSpanAtLevel(update[i], i, 1);
}

x->backward = (update[0] == zsl->header) ? NULL : update[0];
Expand All @@ -195,10 +235,10 @@ void zslDeleteNode(zskiplist *zsl, zskiplistNode *x, zskiplistNode **update) {
int i;
for (i = 0; i < zsl->level; i++) {
if (update[i]->level[i].forward == x) {
update[i]->level[i].span += x->level[i].span - 1;
zslIncrNodeSpanAtLevel(update[i], i, zslGetNodeSpanAtLevel(x, i) - 1);
update[i]->level[i].forward = x->level[i].forward;
} else {
update[i]->level[i].span -= 1;
zslDecrNodeSpanAtLevel(update[i], i, 1);
}
}
if (x->level[0].forward) {
Expand Down Expand Up @@ -336,7 +376,7 @@ zskiplistNode *zslNthInRange(zskiplist *zsl, zrangespec *range, long n) {
x = zsl->header;
i = zsl->level - 1;
while (x->level[i].forward && !zslValueGteMin(x->level[i].forward->score, range)) {
edge_rank += x->level[i].span;
edge_rank += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
/* Remember the last node which has zsl->level-1 levels and its rank. */
Expand All @@ -348,7 +388,7 @@ zskiplistNode *zslNthInRange(zskiplist *zsl, zrangespec *range, long n) {
/* Go forward while *OUT* of range. */
while (x->level[i].forward && !zslValueGteMin(x->level[i].forward->score, range)) {
/* Count the rank of the last element smaller than the range. */
edge_rank += x->level[i].span;
edge_rank += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
}
Expand All @@ -372,7 +412,7 @@ zskiplistNode *zslNthInRange(zskiplist *zsl, zrangespec *range, long n) {
/* Go forward while *IN* range. */
while (x->level[i].forward && zslValueLteMax(x->level[i].forward->score, range)) {
/* Count the rank of the last element in range. */
edge_rank += x->level[i].span;
edge_rank += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
}
Expand Down Expand Up @@ -464,8 +504,8 @@ unsigned long zslDeleteRangeByRank(zskiplist *zsl, unsigned int start, unsigned

x = zsl->header;
for (i = zsl->level - 1; i >= 0; i--) {
while (x->level[i].forward && (traversed + x->level[i].span) < start) {
traversed += x->level[i].span;
while (x->level[i].forward && (traversed + zslGetNodeSpanAtLevel(x, i)) < start) {
traversed += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
update[i] = x;
Expand Down Expand Up @@ -499,7 +539,7 @@ unsigned long zslGetRank(zskiplist *zsl, double score, sds ele) {
while (x->level[i].forward &&
(x->level[i].forward->score < score ||
(x->level[i].forward->score == score && sdscmp(x->level[i].forward->ele, ele) <= 0))) {
rank += x->level[i].span;
rank += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}

Expand All @@ -511,6 +551,18 @@ unsigned long zslGetRank(zskiplist *zsl, double score, sds ele) {
return 0;
}

/* Find the rank of a specific skiplist node.
 *
 * Starting at 'x', repeatedly follow the forward pointer of the highest level
 * of the current node, accumulating the span of each visited node at its own
 * top level. The walk stops at the first node whose top level has no forward
 * pointer. The value returned is zsl->length minus the accumulated spans. */
unsigned long zslGetRankByNode(zskiplist *zsl, zskiplistNode *x) {
    unsigned long dist_to_tail = 0;

    for (;;) {
        /* Always travel on the tallest level the current node offers. */
        int top = zslGetNodeHeight(x) - 1;
        dist_to_tail += zslGetNodeSpanAtLevel(x, top);

        zskiplistNode *next = x->level[top].forward;
        if (next == NULL) {
            break;
        }
        x = next;
    }

    return zsl->length - dist_to_tail;
}

/* Finds an element by its rank from start node. The rank argument needs to be 1-based. */
zskiplistNode *zslGetElementByRankFromNode(zskiplistNode *start_node, int start_level, unsigned long rank) {
zskiplistNode *x;
Expand All @@ -519,8 +571,8 @@ zskiplistNode *zslGetElementByRankFromNode(zskiplistNode *start_node, int start_

x = start_node;
for (i = start_level; i >= 0; i--) {
while (x->level[i].forward && (traversed + x->level[i].span) <= rank) {
traversed += x->level[i].span;
while (x->level[i].forward && (traversed + zslGetNodeSpanAtLevel(x, i)) <= rank) {
traversed += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
if (traversed == rank) {
Expand Down Expand Up @@ -690,7 +742,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) {
x = zsl->header;
i = zsl->level - 1;
while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) {
edge_rank += x->level[i].span;
edge_rank += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
/* Remember the last node which has zsl->level-1 levels and its rank. */
Expand All @@ -702,7 +754,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) {
/* Go forward while *OUT* of range. */
while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) {
/* Count the rank of the last element smaller than the range. */
edge_rank += x->level[i].span;
edge_rank += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
}
Expand All @@ -726,7 +778,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) {
/* Go forward while *IN* range. */
while (x->level[i].forward && zslLexValueLteMax(x->level[i].forward->ele, range)) {
/* Count the rank of the last element in range. */
edge_rank += x->level[i].span;
edge_rank += zslGetNodeSpanAtLevel(x, i);
x = x->level[i].forward;
}
}
Expand Down Expand Up @@ -1173,6 +1225,13 @@ unsigned char *zzlDeleteRangeByRank(unsigned char *zl, unsigned int start, unsig
* Common sorted set API
*----------------------------------------------------------------------------*/

/* Map a hashtable entry to its matching skiplist node (used, for example, by
 * the ZRANK query).
 *
 * The entry's value is treated as the address of the 'score' member embedded
 * in a zskiplistNode; stepping back by that member's offset yields the address
 * of the enclosing node. */
static inline zskiplistNode *zsetGetSLNodeByEntry(dictEntry *de) {
    char *score_addr = (char *)dictGetVal(de);
    char *node_addr = score_addr - offsetof(zskiplistNode, score);
    return (zskiplistNode *)node_addr;
}

unsigned long zsetLength(const robj *zobj) {
unsigned long length = 0;
if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
Expand Down Expand Up @@ -1603,8 +1662,9 @@ long zsetRank(robj *zobj, sds ele, int reverse, double *output_score) {

de = dictFind(zs->dict, ele);
if (de != NULL) {
score = *(double *)dictGetVal(de);
rank = zslGetRank(zsl, score, ele);
zskiplistNode *n = zsetGetSLNodeByEntry(de);
score = n->score;
rank = zslGetRankByNode(zsl, n);
/* Existing elements always have a rank. */
serverAssert(rank != 0);
if (output_score) *output_score = score;
Expand Down
Loading