diff --git a/src/server.h b/src/server.h index 896ff735b3..44de6eada1 100644 --- a/src/server.h +++ b/src/server.h @@ -1449,6 +1449,10 @@ typedef struct zskiplistNode { struct zskiplistNode *backward; struct zskiplistLevel { struct zskiplistNode *forward; + /* At each level we keep the span, which is the number of elements which are on the "subtree" + * from this node at this level to the next node at the same level. + * One exception is the value at level 0. In level 0 the span can only be 1 or 0 (in case the last elements in the list) + * So we use it in order to hold the height of the node, which is the number of levels. */ unsigned long span; } level[]; } zskiplistNode; diff --git a/src/t_zset.c b/src/t_zset.c index a1e71208cb..36a9bfffb1 100644 --- a/src/t_zset.c +++ b/src/t_zset.c @@ -72,12 +72,51 @@ void zsetConvertAndExpand(robj *zobj, int encoding, unsigned long cap); zskiplistNode *zslGetElementByRankFromNode(zskiplistNode *start_node, int start_level, unsigned long rank); zskiplistNode *zslGetElementByRank(zskiplist *zsl, unsigned long rank); +static inline unsigned long zslGetNodeSpanAtLevel(zskiplistNode *x, int level) { + /* We use the level 0 span in order to hold the node height, so in case the span is requested on + * level 0 and this is not the last node we return 1 and 0 otherwise. For the rest of the levels we just return + * the recorded span in that level. */ + if (level > 0) return x->level[level].span; + return x->level[level].forward ? 1 : 0; +} + +static inline void zslSetNodeSpanAtLevel(zskiplistNode *x, int level, unsigned long span) { + /* We use the level 0 span in order to hold the node height, so we avoid overriding it. */ + if (level > 0) + x->level[level].span = span; +} + +static inline void zslIncrNodeSpanAtLevel(zskiplistNode *x, int level, unsigned long incr) { + /* We use the level 0 span in order to hold the node height, so we avoid overriding it. */ + if (level > 0) + x->level[level].span += incr; +} + +static inline void zslDecrNodeSpanAtLevel(zskiplistNode *x, int level, unsigned long decr) { + /* We use the level 0 span in order to hold the node height, so we avoid overriding it. */ + if (level > 0) + x->level[level].span -= decr; +} + +static inline unsigned long zslGetNodeHeight(zskiplistNode *x) { + /* Since the span at level 0 is always 1 (or 0 for the last node), this + * field is instead used for storing the height of the node. */ + return x->level[0].span; +} + +static inline void zslSetNodeHeight(zskiplistNode *x, int height) { + /* Since the span at level 0 is always 1 (or 0 for the last node), this + * field is instead used for storing the height of the node. */ + x->level[0].span = height; +} + /* Create a skiplist node with the specified number of levels. * The SDS string 'ele' is referenced by the node after the call. */ -zskiplistNode *zslCreateNode(int level, double score, sds ele) { - zskiplistNode *zn = zmalloc(sizeof(*zn) + level * sizeof(struct zskiplistLevel)); +zskiplistNode *zslCreateNode(int height, double score, sds ele) { + zskiplistNode *zn = zmalloc(sizeof(*zn) + height * sizeof(struct zskiplistLevel)); zn->score = score; zn->ele = ele; + zslSetNodeHeight(zn, height); return zn; } @@ -147,7 +186,7 @@ zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) { while (x->level[i].forward && (x->level[i].forward->score < score || (x->level[i].forward->score == score && sdscmp(x->level[i].forward->ele, ele) < 0))) { - rank[i] += x->level[i].span; + rank[i] += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } update[i] = x; @@ -161,9 +200,10 @@ zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) { for (i = zsl->level; i < level; i++) { rank[i] = 0; update[i] = zsl->header; - update[i]->level[i].span = zsl->length; + zslSetNodeSpanAtLevel(update[i], i, zsl->length); } zsl->level = level; + zslSetNodeHeight(zsl->header, level); } x = zslCreateNode(level, score, ele); for (i = 0; i < level; i++) { @@ -171,13 +211,13 @@ zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) { update[i]->level[i].forward = x; /* update span covered by update[i] as x is inserted here */ - x->level[i].span = update[i]->level[i].span - (rank[0] - rank[i]); - update[i]->level[i].span = (rank[0] - rank[i]) + 1; + zslSetNodeSpanAtLevel(x, i, zslGetNodeSpanAtLevel(update[i], i) - (rank[0] - rank[i])); + zslSetNodeSpanAtLevel(update[i], i, (rank[0] - rank[i]) + 1); } /* increment span for untouched levels */ for (i = level; i < zsl->level; i++) { - update[i]->level[i].span++; + zslIncrNodeSpanAtLevel(update[i], i, 1); } x->backward = (update[0] == zsl->header) ? NULL : update[0]; @@ -195,10 +235,10 @@ void zslDeleteNode(zskiplist *zsl, zskiplistNode *x, zskiplistNode **update) { int i; for (i = 0; i < zsl->level; i++) { if (update[i]->level[i].forward == x) { - update[i]->level[i].span += x->level[i].span - 1; + zslIncrNodeSpanAtLevel(update[i], i, zslGetNodeSpanAtLevel(x, i) - 1); update[i]->level[i].forward = x->level[i].forward; } else { - update[i]->level[i].span -= 1; + zslDecrNodeSpanAtLevel(update[i], i, 1); } } if (x->level[0].forward) { @@ -336,7 +376,7 @@ zskiplistNode *zslNthInRange(zskiplist *zsl, zrangespec *range, long n) { x = zsl->header; i = zsl->level - 1; while (x->level[i].forward && !zslValueGteMin(x->level[i].forward->score, range)) { - edge_rank += x->level[i].span; + edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } /* Remember the last node which has zsl->level-1 levels and its rank. */ @@ -348,7 +388,7 @@ zskiplistNode *zslNthInRange(zskiplist *zsl, zrangespec *range, long n) { /* Go forward while *OUT* of range. */ while (x->level[i].forward && !zslValueGteMin(x->level[i].forward->score, range)) { /* Count the rank of the last element smaller than the range. */ - edge_rank += x->level[i].span; + edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } } @@ -372,7 +412,7 @@ zskiplistNode *zslNthInRange(zskiplist *zsl, zrangespec *range, long n) { /* Go forward while *IN* range. */ while (x->level[i].forward && zslValueLteMax(x->level[i].forward->score, range)) { /* Count the rank of the last element in range. */ - edge_rank += x->level[i].span; + edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } } @@ -464,8 +504,8 @@ unsigned long zslDeleteRangeByRank(zskiplist *zsl, unsigned int start, unsigned x = zsl->header; for (i = zsl->level - 1; i >= 0; i--) { - while (x->level[i].forward && (traversed + x->level[i].span) < start) { - traversed += x->level[i].span; + while (x->level[i].forward && (traversed + zslGetNodeSpanAtLevel(x, i)) < start) { + traversed += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } update[i] = x; @@ -499,7 +539,7 @@ unsigned long zslGetRank(zskiplist *zsl, double score, sds ele) { while (x->level[i].forward && (x->level[i].forward->score < score || (x->level[i].forward->score == score && sdscmp(x->level[i].forward->ele, ele) <= 0))) { - rank += x->level[i].span; + rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } @@ -511,6 +551,18 @@ unsigned long zslGetRank(zskiplist *zsl, double score, sds ele) { return 0; } +/* Find the rank for a specific skiplist node. */ +unsigned long zslGetRankByNode(zskiplist *zsl, zskiplistNode *x) { + int i = zslGetNodeHeight(x) - 1; + unsigned long rank = zslGetNodeSpanAtLevel(x, i); + while (x->level[zslGetNodeHeight(x) - 1].forward) { + x = x->level[zslGetNodeHeight(x) - 1].forward; + rank += zslGetNodeSpanAtLevel(x, zslGetNodeHeight(x) - 1); + } + rank = zsl->length - rank; + return rank; +} + /* Finds an element by its rank from start node. The rank argument needs to be 1-based. */ zskiplistNode *zslGetElementByRankFromNode(zskiplistNode *start_node, int start_level, unsigned long rank) { zskiplistNode *x; @@ -519,8 +571,8 @@ zskiplistNode *zslGetElementByRankFromNode(zskiplistNode *start_node, int start_ x = start_node; for (i = start_level; i >= 0; i--) { - while (x->level[i].forward && (traversed + x->level[i].span) <= rank) { - traversed += x->level[i].span; + while (x->level[i].forward && (traversed + zslGetNodeSpanAtLevel(x, i)) <= rank) { + traversed += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } if (traversed == rank) { @@ -690,7 +742,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) { x = zsl->header; i = zsl->level - 1; while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) { - edge_rank += x->level[i].span; + edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } /* Remember the last node which has zsl->level-1 levels and its rank. */ @@ -702,7 +754,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) { /* Go forward while *OUT* of range. */ while (x->level[i].forward && !zslLexValueGteMin(x->level[i].forward->ele, range)) { /* Count the rank of the last element smaller than the range. */ - edge_rank += x->level[i].span; + edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } } @@ -726,7 +778,7 @@ zskiplistNode *zslNthInLexRange(zskiplist *zsl, zlexrangespec *range, long n) { /* Go forward while *IN* range. */ while (x->level[i].forward && zslLexValueLteMax(x->level[i].forward->ele, range)) { /* Count the rank of the last element in range. */ - edge_rank += x->level[i].span; + edge_rank += zslGetNodeSpanAtLevel(x, i); x = x->level[i].forward; } } @@ -1173,6 +1225,13 @@ unsigned char *zzlDeleteRangeByRank(unsigned char *zl, unsigned int start, unsig * Common sorted set API *----------------------------------------------------------------------------*/ +/* Utility function used for mapping the hashtable entry to the matching skiplist node. + * For example, this is used in case of ZRANK query. */ +static inline zskiplistNode *zsetGetSLNodeByEntry(dictEntry *de) { + char *score_ref = ((char *)dictGetVal(de)); + return (zskiplistNode *)(score_ref - offsetof(zskiplistNode, score)); +} + unsigned long zsetLength(const robj *zobj) { unsigned long length = 0; if (zobj->encoding == OBJ_ENCODING_LISTPACK) { @@ -1603,8 +1662,9 @@ long zsetRank(robj *zobj, sds ele, int reverse, double *output_score) { de = dictFind(zs->dict, ele); if (de != NULL) { - score = *(double *)dictGetVal(de); - rank = zslGetRank(zsl, score, ele); + zskiplistNode *n = zsetGetSLNodeByEntry(de); + score = n->score; + rank = zslGetRankByNode(zsl, n); /* Existing elements always have a rank. */ serverAssert(rank != 0); if (output_score) *output_score = score;