[BUG] ebtree: fix ebmb_lookup() with len smaller than the tree's keys
(from ebtree 6.0.5)
ebmb_lookup() is used by ebst_lookup_len() to look up a string starting
with a known prefix. Since the prefix is not necessarily zero-terminated,
we must ensure that the comparison stops after <len> bytes; otherwise we
may end up comparing bytes beyond the prefix and, most often, returning
the wrong node when there are multiple matches.
ebim_lookup() was fixed too by resyncing it with ebmb_lookup().
(cherry picked from commit 98eba315aa2c3285181375d312bcb770f058fd2b)
This should be backported to 1.4 though it's not critical there.
diff --git a/ebtree/ebimtree.h b/ebtree/ebimtree.h
index 4d2eea0..205a2db 100644
--- a/ebtree/ebimtree.h
+++ b/ebtree/ebimtree.h
@@ -1,7 +1,7 @@
/*
* Elastic Binary Trees - macros for Indirect Multi-Byte data nodes.
- * Version 6.0
- * (C) 2002-2010 - Willy Tarreau <w@1wt.eu>
+ * Version 6.0.5
+ * (C) 2002-2011 - Willy Tarreau <w@1wt.eu>
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
@@ -32,30 +32,37 @@
REGPRM3 struct ebpt_node *ebim_lookup(struct eb_root *root, const void *x, unsigned int len);
REGPRM3 struct ebpt_node *ebim_insert(struct eb_root *root, struct ebpt_node *new, unsigned int len);
-/* Find the first occurence of a key of <len> bytes in the tree <root>.
- * If none can be found, return NULL.
+/* Find the first occurrence of a key of at least <len> bytes matching <x> in the
+ * tree <root>. The caller is responsible for ensuring that <len> will not exceed
+ * the common parts between the tree's keys and <x>. In case of multiple matches,
+ * the leftmost node is returned. This means that this function can be used to
+ * lookup string keys by prefix if all keys in the tree are zero-terminated. If
+ * no match is found, NULL is returned. Returns first node if <len> is zero.
*/
static forceinline struct ebpt_node *
__ebim_lookup(struct eb_root *root, const void *x, unsigned int len)
{
struct ebpt_node *node;
eb_troot_t *troot;
- int bit;
+ int pos, side;
int node_bit;
troot = root->b[EB_LEFT];
if (unlikely(troot == NULL))
return NULL;
- bit = 0;
+ if (unlikely(len == 0))
+ goto walk_down;
+
+ pos = 0;
while (1) {
- if ((eb_gettag(troot) == EB_LEAF)) {
+ if (eb_gettag(troot) == EB_LEAF) {
node = container_of(eb_untag(troot, EB_LEAF),
struct ebpt_node, node.branches);
- if (memcmp(node->key, x, len) == 0)
- return node;
- else
+ if (memcmp(node->key + pos, x, len) != 0)
return NULL;
+ else
+ return node;
}
node = container_of(eb_untag(troot, EB_NODE),
struct ebpt_node, node.branches);
@@ -66,10 +73,11 @@
* value, and we walk down left, or it's a different
* one and we don't have our key.
*/
- if (memcmp(node->key, x, len) != 0)
+ if (memcmp(node->key + pos, x, len) != 0)
return NULL;
-
+ walk_left:
troot = node->node.branches.b[EB_LEFT];
+ walk_down:
while (eb_gettag(troot) != EB_LEAF)
troot = (eb_untag(troot, EB_NODE))->b[EB_LEFT];
node = container_of(eb_untag(troot, EB_LEAF),
@@ -77,13 +85,38 @@
return node;
}
- /* OK, normal data node, let's walk down */
- bit = equal_bits(x, node->key, bit, node_bit);
- if (bit < node_bit)
- return NULL; /* no more common bits */
+ /* OK, normal data node, let's walk down. We check if all full
+ * bytes are equal, and we start from the last one we did not
+ * completely check. We stop as soon as we reach the last byte,
+ * because we must decide to go left/right or abort.
+ */
+ node_bit = ~node_bit + (pos << 3) + 8; // = (pos<<3) + (7 - node_bit)
+ if (node_bit < 0) {
+ /* This surprising construction gives better performance
+ * because gcc does not try to reorder the loop. Tested to
+ * be fine with 2.95 to 4.2.
+ */
+ while (1) {
+ if (*(unsigned char*)(node->key + pos++) ^ *(unsigned char*)(x++))
+ return NULL; /* more than one full byte is different */
+ if (--len == 0)
+ goto walk_left; /* return first node if all bytes matched */
+ node_bit += 8;
+ if (node_bit >= 0)
+ break;
+ }
+ }
- troot = node->node.branches.b[(((unsigned char*)x)[node_bit >> 3] >>
- (~node_bit & 7)) & 1];
+ /* here we know that only the last byte differs, so node_bit < 8.
+ * We have 2 possibilities :
+ * - more than the last bit differs => return NULL
+ * - walk down on side = (x[pos] >> node_bit) & 1
+ */
+ side = *(unsigned char *)x >> node_bit;
+ if (((*(unsigned char*)(node->key + pos) >> node_bit) ^ side) > 1)
+ return NULL;
+ side &= 1;
+ troot = node->node.branches.b[side];
}
}