/*
 * Copyright 2016, FUJITSU TECHNOLOGY SOLUTIONS GMBH
 * Copyright 2012, Armon Dadgar. All rights reserved.
 * Copyright 2016-2017, Intel Corporation
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in
 *       the documentation and/or other materials provided with the
 *       distribution.
 *
 *     * Neither the name of the copyright holder nor the names of its
 *       contributors may be used to endorse or promote products derived
 *       from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*
 * ==========================================================================
 *
 *     Filename:  art.c
 *
 *  Description:  implement ART tree using libvmem based on libart
 *
 *       Author:  Andreas Bluemle, Dieter Kasper
 *                Andreas.Bluemle.external@ts.fujitsu.com
 *                dieter.kasper@ts.fujitsu.com
 *
 * Organization:  FUJITSU TECHNOLOGY SOLUTIONS GMBH
 * ==========================================================================
 */

/*
 * based on https://github.com/armon/libart/src/art.c
 */
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <stdio.h>
#include <emmintrin.h>
#include <assert.h>
#include "libvmem.h"
#include "art.h"

/*
 * Pointer-tag helpers. A pointer produced by SET_LEAF has its lowest
 * address bit set; IS_LEAF tests that bit and LEAF_RAW clears it again
 * to recover the untagged address.
 */
#define IS_LEAF(x) (((uintptr_t)(x) & (uintptr_t)1))
#define SET_LEAF(x) ((void *)((uintptr_t)(x) | (uintptr_t)1))
#define LEAF_RAW(x) ((void *)((uintptr_t)(x) & ~(uintptr_t)1))

/*
 * Allocates a zero-filled inner node of the given type from the vmem
 * pool and records the type in the node header.
 *
 * @arg vmp The memory pool to allocate from
 * @arg type One of NODE4, NODE16, NODE48 or NODE256; any other value
 * aborts the process
 * @return The new node. Never NULL: an allocation failure aborts.
 */
static art_node *
alloc_node(VMEM *vmp, uint8_t type)
{
	size_t size;

	switch (type) {
	case NODE4:
		size = sizeof(art_node4);
		break;
	case NODE16:
		size = sizeof(art_node16);
		break;
	case NODE48:
		size = sizeof(art_node48);
		break;
	case NODE256:
		size = sizeof(art_node256);
		break;
	default:
		abort();
	}

	art_node *n = vmem_calloc(vmp, 1, size);

	/*
	 * The assert gives a diagnostic in debug builds; the explicit check
	 * keeps the abort-on-failure policy when NDEBUG compiles the
	 * assert away, instead of writing through a NULL pointer below.
	 */
	assert(n != NULL);
	if (n == NULL)
		abort();

	n->type = type;
	return n;
}

/*
 * Puts a tree into its empty state: no root node and an entry
 * count of zero.
 * @arg t The tree to initialize
 * @return 0 on success.
 */
int
art_tree_init(art_tree *t)
{
	t->size = 0;
	t->root = NULL;

	return 0;
}

/*
 * Recursively frees a subtree: every leaf and inner node reachable
 * from n is returned to the vmem pool.
 *
 * @arg vmp The memory pool the nodes were allocated from
 * @arg n Root of the subtree; may be NULL or a tagged leaf pointer
 */
static void
destroy_node(VMEM *vmp, art_node *n)
{
	// Nothing to do for an empty slot
	if (!n)
		return;

	// Leaves are stored tagged; free the untagged allocation
	if (IS_LEAF(n)) {
		vmem_free(vmp, LEAF_RAW(n));
		return;
	}

	// Release the children according to the node layout
	int i;
	union {
		art_node4 *p1;
		art_node16 *p2;
		art_node48 *p3;
		art_node256 *p4;
	} p;
	switch (n->type) {
	case NODE4:
		p.p1 = (art_node4 *)n;
		for (i = 0; i < n->num_children; i++) {
			destroy_node(vmp, p.p1->children[i]);
		}
		break;

	case NODE16:
		p.p2 = (art_node16 *)n;
		for (i = 0; i < n->num_children; i++) {
			destroy_node(vmp, p.p2->children[i]);
		}
		break;

	case NODE48:
		/*
		 * remove_child48 clears a slot without compacting the array,
		 * so live children are not guaranteed to occupy the first
		 * num_children slots. Visit all 48 slots; empty ones are
		 * NULL and are ignored by the recursive call.
		 */
		p.p3 = (art_node48 *)n;
		for (i = 0; i < 48; i++) {
			destroy_node(vmp, p.p3->children[i]);
		}
		break;

	case NODE256:
		p.p4 = (art_node256 *)n;
		for (i = 0; i < 256; i++) {
			if (p.p4->children[i])
				destroy_node(vmp, p.p4->children[i]);
		}
		break;

	default:
		abort();
	}

	// Free the node itself once its children are gone
	vmem_free(vmp, n);
}

/*
 * Frees every node and leaf of an ART tree.
 * The tree structure itself is not modified: t->root and t->size keep
 * their previous values after the call.
 * @arg vmp The memory pool the tree was allocated from
 * @arg t The tree to tear down
 * @return 0 on success.
 */
int
art_tree_destroy(VMEM *vmp, art_tree *t)
{
	art_node *root = t->root;

	destroy_node(vmp, root);
	return 0;
}

/*
 * Looks up the child of an inner node that is registered for a key byte.
 * @arg n The inner node to inspect
 * @arg c The key byte identifying the child
 * @return The address of the matching child slot inside the node, or
 * NULL if the node has no child for that byte. An unknown node type
 * aborts the process.
 */
static art_node **
find_child(art_node *n, unsigned char c)
{
	art_node **slot = NULL;

	switch (n->type) {
	case NODE4: {
		art_node4 *n4 = (art_node4 *)n;

		for (int k = 0; !slot && k < n->num_children; k++) {
			if (n4->keys[k] == c)
				slot = &n4->children[k];
		}
		break;
	}

	case NODE16: {
		art_node16 *n16 = (art_node16 *)n;

		// One byte-wise equality test against all 16 key slots
		__m128i stored = _mm_loadu_si128((__m128i *)n16->keys);
		__m128i eq = _mm_cmpeq_epi8(_mm_set1_epi8(c), stored);

		// Discard hits in slots beyond the current child count
		int live = (1 << n->num_children) - 1;
		int hits = _mm_movemask_epi8(eq) & live;

		// The lowest set bit is the index of the matching slot
		if (hits)
			slot = &n16->children[__builtin_ctz(hits)];
		break;
	}

	case NODE48: {
		art_node48 *n48 = (art_node48 *)n;

		// keys[] stores the child index plus one; zero means absent
		int pos = n48->keys[c];

		if (pos)
			slot = &n48->children[pos - 1];
		break;
	}

	case NODE256: {
		art_node256 *n256 = (art_node256 *)n;

		if (n256->children[c])
			slot = &n256->children[c];
		break;
	}

	default:
		abort();
	}

	return slot;
}

// Returns the smaller of two ints (b when both compare equal)
static inline int
min(int a, int b)
{
	if (a < b)
		return a;
	return b;
}

/*
 * Counts how many leading bytes of the node's stored prefix equal the
 * key bytes starting at depth. At most MAX_PREFIX_LEN bytes, and never
 * more than the key has left, are compared.
 * @arg n The inner node whose prefix is checked
 * @arg key The key
 * @arg key_len The length of the key
 * @arg depth Offset into the key at which the comparison starts
 * @return The number of matching bytes.
 */
static int
check_prefix(const art_node *n, const unsigned char *key, int key_len,
		int depth)
{
	int stored = min(n->partial_len, MAX_PREFIX_LEN);
	int limit = min(stored, key_len - depth);
	int matched = 0;

	while (matched < limit &&
	    n->partial[matched] == key[depth + matched])
		matched++;

	return matched;
}

/*
 * Compares the complete key stored in a leaf with the given key.
 * @arg n The leaf
 * @arg key The key to compare against
 * @arg key_len The length of the key
 * @arg depth Unused; the whole key is always compared
 * @return 0 if the keys are identical, non-zero otherwise.
 */
static int
leaf_matches(const art_leaf *n, const unsigned char *key, int key_len,
		int depth)
{
	(void) depth;

	// Only keys of equal length can be identical
	if (n->key_len == (uint32_t)key_len)
		return memcmp(n->key, key, key_len);

	return 1;
}

/*
 * Searches for a value in the ART tree
 * @arg t The tree
 * @arg key The key
 * @arg key_len The length of the key
 * @return NULL if the item was not found, otherwise
 * the value pointer is returned.
 */
void *
art_search(const art_tree *t, const unsigned char *key, int key_len)
{
	art_node **child;
	art_node *n = t->root;
	int prefix_len, depth = 0;
	while (n) {
		// Might be a leaf
		if (IS_LEAF(n)) {
			n = LEAF_RAW(n);
			// The leaf holds the full key; compare all of it
			if (!leaf_matches((art_leaf *)n,
			    key, key_len, depth)) {
				return ((art_leaf *)n)->value;
			}
			return NULL;
		}

		// Bail if the prefix does not match
		if (n->partial_len) {
			prefix_len = check_prefix(n, key, key_len, depth);
			if (prefix_len != min(MAX_PREFIX_LEN,
			    n->partial_len))
				return NULL;
			depth = depth + n->partial_len;
		}

		/*
		 * The key is exhausted at an inner node, so it cannot be
		 * stored below this node. Stop here instead of reading
		 * key[depth] beyond the caller's buffer.
		 */
		if (depth >= key_len)
			return NULL;

		// Descend into the child selected by the next key byte
		child = find_child(n, key[depth]);
		n = (child) ? *child : NULL;
		depth++;
	}
	return NULL;
}

/*
 * Walks down from n, always following the first child in node order,
 * until a leaf is reached.
 * @return The leaf found, or NULL if n is NULL.
 */
static art_leaf *
minimum(const art_node *n)
{
	while (n && !IS_LEAF(n)) {
		int idx = 0;

		switch (n->type) {
		case NODE4:
			n = ((art_node4 *)n)->children[0];
			break;
		case NODE16:
			n = ((art_node16 *)n)->children[0];
			break;
		case NODE48:
			// Lowest key byte that has a child registered
			while (!((art_node48 *)n)->keys[idx])
				idx++;
			// keys[] stores the child index plus one
			idx = ((art_node48 *)n)->keys[idx] - 1;
			assert(idx < 48);
			n = ((art_node48 *)n)->children[idx];
			break;
		case NODE256:
			while (!((art_node256 *)n)->children[idx])
				idx++;
			n = ((art_node256 *)n)->children[idx];
			break;
		default:
			abort();
		}
	}

	if (!n)
		return NULL;
	return LEAF_RAW(n);
}

/*
 * Walks down from n, always following the last child in node order,
 * until a leaf is reached.
 * @return The leaf found, or NULL if n is NULL.
 */
static art_leaf *
maximum(const art_node *n)
{
	while (n && !IS_LEAF(n)) {
		int idx = 255;

		switch (n->type) {
		case NODE4:
			n = ((art_node4 *)n)->children[n->num_children - 1];
			break;
		case NODE16:
			n = ((art_node16 *)n)->children[n->num_children - 1];
			break;
		case NODE48:
			// Highest key byte that has a child registered
			while (!((art_node48 *)n)->keys[idx])
				idx--;
			// keys[] stores the child index plus one
			idx = ((art_node48 *)n)->keys[idx] - 1;
			assert((idx >= 0) && (idx < 48));
			n = ((art_node48 *)n)->children[idx];
			break;
		case NODE256:
			while (!((art_node256 *)n)->children[idx])
				idx--;
			n = ((art_node256 *)n)->children[idx];
			break;
		default:
			abort();
		}
	}

	if (!n)
		return NULL;
	return LEAF_RAW(n);
}

/*
 * Returns the first leaf of the tree in node order
 * @arg t The tree
 * @return The leaf, or NULL if the tree has no root.
 */
art_leaf *
art_minimum(art_tree *t)
{
	art_leaf *first = minimum(t->root);

	return first;
}

/*
 * Returns the last leaf of the tree in node order
 * @arg t The tree
 * @return The leaf, or NULL if the tree has no root.
 */
art_leaf *
art_maximum(art_tree *t)
{
	art_leaf *last = maximum(t->root);

	return last;
}

/*
 * Allocates a leaf that carries private copies of the key and the
 * value. Both copies live in the leaf's trailing data area: the key
 * first, the value directly behind it.
 *
 * @arg vmp The memory pool to allocate from
 * @arg key The key bytes to copy
 * @arg key_len The length of the key
 * @arg value The value bytes to copy
 * @arg val_len The length of the value
 * @return The new, untagged leaf. Never NULL: an allocation failure
 * aborts the process.
 */
static art_leaf *
make_leaf(VMEM *vmp, const unsigned char *key, int key_len, void *value,
	    int val_len)
{
	art_leaf *l = vmem_malloc(vmp, sizeof(art_leaf) + key_len + val_len);

	/*
	 * The assert gives a diagnostic in debug builds; the explicit check
	 * keeps the abort-on-failure policy when NDEBUG compiles the
	 * assert away, instead of writing through a NULL pointer below.
	 */
	assert(l != NULL);
	if (l == NULL)
		abort();

	l->key_len = key_len;
	l->val_len = val_len;
	l->key = &(l->data[0]) + 0;
	l->value = &(l->data[0]) + key_len;
	memcpy(l->key, key, key_len);
	memcpy(l->value, value, val_len);
	return l;
}

/*
 * Counts how many bytes, starting at offset depth, the keys of two
 * leaves have in common. The comparison stops at the end of the
 * shorter key.
 * @return The number of equal bytes from depth onwards.
 */
static int
longest_common_prefix(art_leaf *l1, art_leaf *l2, int depth)
{
	int limit = min(l1->key_len, l2->key_len) - depth;
	int same = 0;

	while (same < limit &&
	    l1->key[depth + same] == l2->key[depth + same])
		same++;

	return same;
}

/*
 * Copies the common node header fields from src to dest: the child
 * count, the prefix length and the stored prefix bytes (at most
 * MAX_PREFIX_LEN of them). The node type is not copied.
 */
static void
copy_header(art_node *dest, art_node *src)
{
	int stored = min(MAX_PREFIX_LEN, src->partial_len);

	memcpy(dest->partial, src->partial, stored);
	dest->partial_len = src->partial_len;
	dest->num_children = src->num_children;
}

/*
 * Registers child under key byte c in a NODE256. The node is indexed
 * directly by the byte and is never replaced, so ref is unused.
 */
static void
add_child256(VMEM *vmp, art_node256 *n, art_node **ref, unsigned char c,
		void *child)
{
	(void) ref;

	n->children[c] = child;
	n->n.num_children += 1;
}

/*
 * Registers child under key byte c in a NODE48. A full node is
 * replaced by a NODE256 first: *ref is pointed at the new node and the
 * old node is freed.
 */
static void
add_child48(VMEM *vmp, art_node48 *n, art_node **ref, unsigned char c,
		void *child)
{
	if (n->n.num_children >= 48) {
		// No room left: move everything into a NODE256
		art_node256 *bigger = (art_node256 *)alloc_node(vmp, NODE256);

		for (int byte = 0; byte < 256; byte++) {
			int slot = n->keys[byte];

			if (slot)
				bigger->children[byte] = n->children[slot - 1];
		}
		copy_header((art_node *)bigger, (art_node *)n);
		*ref = (art_node *)bigger;
		vmem_free(vmp, n);
		add_child256(vmp, bigger, ref, c, child);
		return;
	}

	// Take the first unused child slot
	int slot = 0;

	while (n->children[slot])
		slot++;
	n->children[slot] = child;

	// keys[] stores the child index plus one; zero means absent
	n->keys[c] = slot + 1;
	n->n.num_children++;
}

/*
 * Registers child under key byte c in a NODE16, keeping the key array
 * sorted in ascending unsigned byte order. A full node is replaced by a
 * NODE48 first: *ref is pointed at the new node and the old node is
 * freed.
 *
 * @arg vmp The memory pool used when the node has to grow
 * @arg n The node to add to
 * @arg ref The slot that points at n; updated if the node is replaced
 * @arg c The key byte of the new child
 * @arg child The child pointer to store (may be a tagged leaf)
 */
static void
add_child16(VMEM *vmp, art_node16 *n, art_node **ref, unsigned char c,
		void *child)
{
	if (n->n.num_children < 16) {
		/*
		 * _mm_cmplt_epi8 compares *signed* bytes, which would order
		 * key bytes >= 0x80 before smaller ones and disagree with
		 * the unsigned ordering used by add_child4. Flipping the top
		 * bit of both operands maps unsigned order onto signed order.
		 */
		const __m128i bias = _mm_set1_epi8((char)-128);
		__m128i wanted = _mm_xor_si128(_mm_set1_epi8((char)c), bias);
		__m128i stored = _mm_xor_si128(
			    _mm_loadu_si128((__m128i *)n->keys), bias);

		// Bit i is set when c is smaller than the key in slot i
		__m128i cmp = _mm_cmplt_epi8(wanted, stored);

		// Use a mask to ignore children that don't exist
		unsigned mask = (1 << n->n.num_children) - 1;
		unsigned bitfield = _mm_movemask_epi8(cmp) & mask;

		// Insert before the first larger key, or append
		unsigned idx;
		if (bitfield) {
			idx = __builtin_ctz(bitfield);
			memmove(n->keys + idx + 1, n->keys + idx,
			    n->n.num_children - idx);
			memmove(n->children + idx + 1, n->children + idx,
			    (n->n.num_children - idx) * sizeof(void *));
		} else
			idx = n->n.num_children;

		// Set the child
		n->keys[idx] = c;
		n->children[idx] = child;
		n->n.num_children++;
	} else {
		art_node48 *new = (art_node48 *)alloc_node(vmp, NODE48);

		// Copy the child pointers and populate the key map
		memcpy(new->children, n->children,
		    sizeof(void *) * n->n.num_children);
		for (int i = 0; i < n->n.num_children; i++) {
			new->keys[n->keys[i]] = i + 1;
		}
		copy_header((art_node *)new, (art_node *)n);
		*ref = (art_node *) new;
		vmem_free(vmp, n);
		add_child48(vmp, new, ref, c, child);
	}
}

/*
 * Registers child under key byte c in a NODE4, keeping the key array
 * sorted in ascending byte order. A full node is replaced by a NODE16
 * first: *ref is pointed at the new node and the old node is freed.
 */
static void
add_child4(VMEM *vmp, art_node4 *n, art_node **ref, unsigned char c,
		void *child)
{
	if (n->n.num_children >= 4) {
		// No room left: move everything into a NODE16
		art_node16 *bigger = (art_node16 *)alloc_node(vmp, NODE16);

		memcpy(bigger->children, n->children,
		    sizeof(void *) * n->n.num_children);
		memcpy(bigger->keys, n->keys,
		    sizeof(unsigned char) * n->n.num_children);
		copy_header((art_node *)bigger, (art_node *)n);
		*ref = (art_node *)bigger;
		vmem_free(vmp, n);
		add_child16(vmp, bigger, ref, c, child);
		return;
	}

	// Skip every key that is not larger than the new one
	int pos = 0;

	while (pos < n->n.num_children && c >= n->keys[pos])
		pos++;

	// Open a gap at the insert position
	int tail = n->n.num_children - pos;

	memmove(n->keys + pos + 1, n->keys + pos, tail);
	memmove(n->children + pos + 1, n->children + pos,
	    tail * sizeof(void *));

	// Fill the gap
	n->keys[pos] = c;
	n->children[pos] = child;
	n->n.num_children++;
}

/*
 * Registers child under key byte c in an inner node by dispatching to
 * the add routine of the node's type.
 *
 * @arg vmp The memory pool used when the node has to grow
 * @arg n The node to add to
 * @arg ref The slot that points at n; updated if the node is replaced
 * @arg c The key byte of the new child
 * @arg child The child pointer to store (may be a tagged leaf)
 */
static void
add_child(VMEM *vmp, art_node *n, art_node **ref, unsigned char c,
		void *child)
{
	/*
	 * The helpers return void. ISO C does not allow "return expr;" in
	 * a void function, so call them and leave the switch instead.
	 */
	switch (n->type) {
	case NODE4:
		add_child4(vmp, (art_node4 *)n, ref, c, child);
		break;
	case NODE16:
		add_child16(vmp, (art_node16 *)n, ref, c, child);
		break;
	case NODE48:
		add_child48(vmp, (art_node48 *)n, ref, c, child);
		break;
	case NODE256:
		add_child256(vmp, (art_node256 *)n, ref, c, child);
		break;
	default:
		abort();
	}
}

/*
 * Calculates the index at which the node prefix and the key (from
 * depth onwards) first differ. Up to MAX_PREFIX_LEN bytes are compared
 * against the prefix stored in the node; if the node's prefix is longer
 * than that, the comparison continues against the key of a leaf below
 * the node.
 * @return The number of matching bytes.
 */
static int
prefix_mismatch(const art_node *n, const unsigned char *key, int key_len,
		int depth)
{
	int idx = 0;
	int limit = min(min(MAX_PREFIX_LEN, n->partial_len), key_len - depth);

	// Compare against the bytes stored in the node
	while (idx < limit) {
		if (n->partial[idx] != key[depth + idx])
			return idx;
		idx++;
	}

	// The stored bytes are the whole prefix: no leaf needed
	if (n->partial_len <= MAX_PREFIX_LEN)
		return idx;

	// The prefix is longer than what is stored; consult a leaf
	art_leaf *l = minimum(n);

	assert(l != NULL);
	limit = min(l->key_len, key_len) - depth;
	while (idx < limit) {
		if (l->key[idx + depth] != key[depth + idx])
			return idx;
		idx++;
	}
	return idx;
}

/*
 * Inserts a key/value pair into the subtree rooted at n.
 *
 * @arg vmp The memory pool used for new leaves and nodes
 * @arg n The current node (NULL, a tagged leaf, or an inner node)
 * @arg ref The slot that points at n; rewritten when n is replaced
 * @arg key The key
 * @arg key_len The length of the key
 * @arg value The value
 * @arg val_len The length of the value
 * @arg depth Number of key bytes already consumed on the way to n
 * @arg old Set to 1 when an existing leaf with the same key was updated;
 * left untouched otherwise
 * @return The previous value pointer when an existing key was updated,
 * NULL when a new leaf was created.
 */
static void *
recursive_insert(VMEM *vmp, art_node *n, art_node **ref,
		const unsigned char *key, int key_len, void *value,
		int val_len, int depth, int *old)
{
	// If we are at a NULL node, inject a leaf
	if (!n) {
		*ref = (art_node *)SET_LEAF(
			    make_leaf(vmp, key, key_len, value, val_len));
		return NULL;
	}

	// If we are at a leaf, we need to replace it with a node
	if (IS_LEAF(n)) {
		art_leaf *l = LEAF_RAW(n);

		// Check if we are updating an existing value
		if (!leaf_matches(l, key, key_len, depth)) {
			/*
			 * NOTE(review): unlike make_leaf, which copies the
			 * value into the leaf, this path stores the caller's
			 * pointer as-is and leaves l->val_len unchanged.
			 * Confirm that this is what callers expect.
			 */
			*old = 1;
			void *old_val = l->value;
			l->value = value;
			return old_val;
		}

		// New value, we must split the leaf into a node4
		art_node4 *new = (art_node4 *)alloc_node(vmp, NODE4);

		// Create a new leaf
		art_leaf *l2 = make_leaf(vmp, key, key_len, value, val_len);

		// Determine longest prefix
		int longest_prefix = longest_common_prefix(l, l2, depth);
		new->n.partial_len = longest_prefix;
		memcpy(new->n.partial, key + depth,
		    min(MAX_PREFIX_LEN, longest_prefix));
		// Add the leafs to the new node4, keyed by the first byte
		// after the common prefix
		*ref = (art_node *)new;
		add_child4(vmp, new, ref, l->key[depth + longest_prefix],
		    SET_LEAF(l));
		add_child4(vmp, new, ref, l2->key[depth + longest_prefix],
		    SET_LEAF(l2));
		return NULL;
	}

	// Check if given node has a prefix
	if (n->partial_len) {
		// Determine if the prefixes differ, since we need to split
		int prefix_diff = prefix_mismatch(n, key, key_len, depth);
		if ((uint32_t)prefix_diff >= n->partial_len) {
			// Whole prefix matched: skip it and descend
			depth += n->partial_len;
			goto RECURSE_SEARCH;
		}

		// Create a new node that holds the common part of the prefix
		art_node4 *new = (art_node4 *)alloc_node(vmp, NODE4);
		*ref = (art_node *)new;
		new->n.partial_len = prefix_diff;
		memcpy(new->n.partial, n->partial,
		    min(MAX_PREFIX_LEN, prefix_diff));

		// Adjust the prefix of the old node
		if (n->partial_len <= MAX_PREFIX_LEN) {
			// All prefix bytes are stored in the node itself
			add_child4(vmp, new, ref,
			    n->partial[prefix_diff], n);
			n->partial_len -= (prefix_diff + 1);
			memmove(n->partial, n->partial + prefix_diff + 1,
			    min(MAX_PREFIX_LEN, n->partial_len));
		} else {
			// Prefix is longer than what the node stores; take
			// the missing bytes from the key of a leaf below n
			n->partial_len -= (prefix_diff + 1);
			art_leaf *l = minimum(n);
			assert(l != NULL);
			add_child4(vmp, new, ref,
			    l->key[depth + prefix_diff], n);
			memcpy(n->partial, l->key + depth + prefix_diff + 1,
			    min(MAX_PREFIX_LEN, n->partial_len));
		}

		// Insert the new leaf
		art_leaf *l = make_leaf(vmp, key, key_len, value, val_len);
		add_child4(vmp, new, ref,
		    key[depth + prefix_diff], SET_LEAF(l));
		return NULL;
	}

RECURSE_SEARCH:;

	// Find a child to recurse to
	art_node **child = find_child(n, key[depth]);
	if (child) {
		return recursive_insert(vmp, *child, child, key, key_len,
			    value, val_len, depth + 1, old);
	}

	// No child, node goes within us
	art_leaf *l = make_leaf(vmp, key, key_len, value, val_len);
	add_child(vmp, n, ref, key[depth], SET_LEAF(l));
	return NULL;
}

/*
 * Inserts a new value into the ART tree
 * @arg vmp The memory pool used for new leaves and nodes
 * @arg t The tree
 * @arg key The key
 * @arg key_len The length of the key
 * @arg value The value
 * @arg val_len The length of the value
 * @return NULL if the item was newly inserted, otherwise
 * the old value pointer is returned. The tree's entry count only
 * grows for newly inserted items.
 */
void *
art_insert(VMEM *vmp, art_tree *t, const unsigned char *key, int key_len,
		void *value, int val_len)
{
	int replaced = 0;
	void *previous = recursive_insert(vmp, t->root, &t->root, key,
	    key_len, value, val_len, 0, &replaced);

	if (replaced == 0)
		t->size += 1;

	return previous;
}

/*
 * Unregisters the child stored under key byte c in a NODE256. When the
 * child count drops to 37 the node is replaced by a NODE48: *ref is
 * pointed at the new node and the old node is freed.
 */
static void
remove_child256(VMEM *vmp, art_node256 *n, art_node **ref,
		unsigned char c)
{
	n->children[c] = NULL;
	n->n.num_children--;

	// Shrink well below the 48/49 boundary rather than right at it,
	// so alternating inserts and removes do not thrash
	if (n->n.num_children != 37)
		return;

	art_node48 *smaller = (art_node48 *)alloc_node(vmp, NODE48);

	*ref = (art_node *) smaller;
	copy_header((art_node *)smaller, (art_node *)n);

	// Pack the remaining children and record their key bytes
	int used = 0;

	for (int byte = 0; byte < 256; byte++) {
		if (!n->children[byte])
			continue;
		assert(used < 48);
		smaller->children[used] = n->children[byte];
		smaller->keys[byte] = used + 1;
		used++;
	}
	vmem_free(vmp, n);
}

/*
 * Unregisters the child stored under key byte c in a NODE48. The child
 * slot is cleared in place. When the child count drops to 12 the node is
 * replaced by a NODE16: *ref is pointed at the new node and the old
 * node is freed.
 */
static void
remove_child48(VMEM *vmp, art_node48 *n, art_node **ref, unsigned char c)
{
	// keys[] stores the child index plus one
	int slot = n->keys[c];

	n->keys[c] = 0;
	n->children[slot - 1] = NULL;
	n->n.num_children--;

	if (n->n.num_children != 12)
		return;

	art_node16 *smaller = (art_node16 *)alloc_node(vmp, NODE16);

	*ref = (art_node *)smaller;
	copy_header((art_node *) smaller, (art_node *)n);

	// Walking the key bytes in ascending order yields sorted keys
	int filled = 0;

	for (int byte = 0; byte < 256; byte++) {
		int idx = n->keys[byte];

		if (!idx)
			continue;
		assert(filled < 16);
		smaller->keys[filled] = byte;
		smaller->children[filled] = n->children[idx - 1];
		filled++;
	}
	vmem_free(vmp, n);
}

/*
 * Removes the child whose slot address is l from a NODE16 and closes
 * the gap in the key and child arrays. When the child count drops to 3
 * the node is replaced by a NODE4: *ref is pointed at the new node and
 * the old node is freed.
 */
static void
remove_child16(VMEM *vmp, art_node16 *n, art_node **ref, art_node **l)
{
	int pos = l - n->children;
	int tail = n->n.num_children - 1 - pos;

	// Shift the entries behind the removed one down by one
	memmove(n->keys + pos, n->keys + pos + 1, tail);
	memmove(n->children + pos, n->children + pos + 1,
	    tail * sizeof(void *));
	n->n.num_children--;

	if (n->n.num_children != 3)
		return;

	art_node4 *smaller = (art_node4 *)alloc_node(vmp, NODE4);

	*ref = (art_node *) smaller;
	copy_header((art_node *)smaller, (art_node *)n);

	// A full NODE4 worth of entries is copied; the header's child
	// count (3) determines how many of them are in use
	memcpy(smaller->keys, n->keys, 4);
	memcpy(smaller->children, n->children, 4 * sizeof(void *));
	vmem_free(vmp, n);
}

/*
 * Removes the child whose slot address is l from a NODE4 and closes the
 * gap in the key and child arrays. If exactly one child remains, the
 * node is dropped: *ref is pointed at that child and the node is freed.
 * When the remaining child is an inner node, its prefix is extended by
 * this node's prefix and the key byte that led to it.
 */
static void
remove_child4(VMEM *vmp, art_node4 *n, art_node **ref, art_node **l)
{
	// Index of the removed slot within the child array
	int pos = l - n->children;
	memmove(n->keys + pos, n->keys + pos + 1, n->n.num_children - 1 - pos);
	memmove(n->children + pos, n->children + pos + 1,
	    (n->n.num_children - 1 - pos) * sizeof(void *));
	n->n.num_children--;

	// Remove nodes with only a single child
	if (n->n.num_children == 1) {
		art_node *child = n->children[0];
		if (!IS_LEAF(child)) {
			// Concatenate the prefixes: this node's prefix, then
			// the key byte of the child, then the child's prefix.
			// Only the first MAX_PREFIX_LEN bytes are stored.
			int prefix = n->n.partial_len;
			if (prefix < MAX_PREFIX_LEN) {
				n->n.partial[prefix] = n->keys[0];
				prefix++;
			}
			if (prefix < MAX_PREFIX_LEN) {
				int sub_prefix =
					min(child->partial_len,
					    MAX_PREFIX_LEN - prefix);
				memcpy(n->n.partial + prefix,
				    child->partial, sub_prefix);
				prefix += sub_prefix;
			}

			// Store the prefix in the child; the logical length
			// grows by this node's full prefix plus the key byte
			memcpy(child->partial, n->n.partial,
			    min(prefix, MAX_PREFIX_LEN));
			child->partial_len += n->n.partial_len + 1;
		}
		*ref = child;
		vmem_free(vmp, n);
	}
}

/*
 * Removes a child from an inner node by dispatching to the remove
 * routine of the node's type.
 *
 * @arg vmp The memory pool used when the node has to shrink
 * @arg n The node to remove from
 * @arg ref The slot that points at n; updated if the node is replaced
 * @arg c The key byte of the child (used by NODE48 and NODE256)
 * @arg l The address of the child's slot (used by NODE4 and NODE16)
 */
static void
remove_child(VMEM *vmp, art_node *n, art_node **ref, unsigned char c,
		art_node **l)
{
	/*
	 * The helpers return void. ISO C does not allow "return expr;" in
	 * a void function, so call them and leave the switch instead.
	 */
	switch (n->type) {
	case NODE4:
		remove_child4(vmp, (art_node4 *)n, ref, l);
		break;
	case NODE16:
		remove_child16(vmp, (art_node16 *)n, ref, l);
		break;
	case NODE48:
		remove_child48(vmp, (art_node48 *)n, ref, c);
		break;
	case NODE256:
		remove_child256(vmp, (art_node256 *)n, ref, c);
		break;
	default:
		abort();
	}
}

/*
 * Unlinks the leaf that stores the given key from the subtree rooted
 * at n. The leaf itself is not freed; that is left to the caller.
 *
 * @arg vmp The memory pool used when a node has to shrink
 * @arg n The current node (NULL, a tagged leaf, or an inner node)
 * @arg ref The slot that points at n; rewritten when n is replaced
 * @arg key The key
 * @arg key_len The length of the key
 * @arg depth Number of key bytes already consumed on the way to n
 * @return The unlinked leaf, or NULL if the key is not in the subtree.
 */
static art_leaf *
recursive_delete(VMEM *vmp, art_node *n, art_node **ref,
		const unsigned char *key, int key_len, int depth)
{
	// Search terminated
	if (!n)
		return NULL;

	// Handle hitting a leaf node
	if (IS_LEAF(n)) {
		art_leaf *l = LEAF_RAW(n);
		if (!leaf_matches(l, key, key_len, depth)) {
			*ref = NULL;
			return l;
		}
		return NULL;
	}

	// Bail if the prefix does not match
	if (n->partial_len) {
		int prefix_len = check_prefix(n, key, key_len, depth);
		if (prefix_len != min(MAX_PREFIX_LEN, n->partial_len)) {
			return NULL;
		}
		depth = depth + n->partial_len;
	}

	/*
	 * The key is exhausted at an inner node, so it cannot be stored
	 * below this node. Stop here instead of reading key[depth] beyond
	 * the caller's buffer.
	 */
	if (depth >= key_len)
		return NULL;

	// Find child node
	art_node **child = find_child(n, key[depth]);
	if (!child)
		return NULL;

	// Inner node: keep descending
	if (!IS_LEAF(*child)) {
		return recursive_delete(vmp, *child, child, key,
			    key_len, depth + 1);
	}

	// The child is a leaf: remove it from this node if the key matches
	art_leaf *l = LEAF_RAW(*child);
	if (leaf_matches(l, key, key_len, depth))
		return NULL;

	remove_child(vmp, n, ref, key[depth], child);
	return l;
}

/*
 * Deletes a value from the ART tree
 * @arg vmp The memory pool the tree was allocated from
 * @arg t The tree
 * @arg key The key
 * @arg key_len The length of the key
 * @return NULL if the item was not found, otherwise
 * the value pointer is returned.
 *
 * NOTE(review): the leaf is freed before returning. For a leaf built by
 * make_leaf the value pointer refers to storage inside that leaf, so
 * the returned pointer is only safe to compare against NULL, not to
 * dereference.
 */
void *
art_delete(VMEM *vmp, art_tree *t, const unsigned char *key, int key_len)
{
	art_leaf *removed =
	    recursive_delete(vmp, t->root, &t->root, key, key_len, 0);

	if (removed == NULL)
		return NULL;

	void *old = removed->value;

	t->size--;
	vmem_free(vmp, removed);
	return old;
}

/*
 * Visits every leaf below n in node order and hands its key and value
 * to the callback. The walk stops at the first non-zero callback result.
 * @return 0 if all leaves were visited, otherwise the callback's
 * non-zero result.
 */
static int
recursive_iter(art_node *n, art_callback cb, void *data)
{
	// Empty slot: nothing to report
	if (!n)
		return 0;

	// A leaf is reported directly
	if (IS_LEAF(n)) {
		art_leaf *leaf = LEAF_RAW(n);

		return cb(data, (const unsigned char *)leaf->key,
		    leaf->key_len, leaf->value, leaf->val_len);
	}

	int rc = 0;

	switch (n->type) {
	case NODE4: {
		art_node4 *n4 = (art_node4 *)n;

		for (int k = 0; rc == 0 && k < n->num_children; k++)
			rc = recursive_iter(n4->children[k], cb, data);
		break;
	}

	case NODE16: {
		art_node16 *n16 = (art_node16 *)n;

		for (int k = 0; rc == 0 && k < n->num_children; k++)
			rc = recursive_iter(n16->children[k], cb, data);
		break;
	}

	case NODE48: {
		art_node48 *n48 = (art_node48 *)n;

		for (int byte = 0; rc == 0 && byte < 256; byte++) {
			// keys[] stores the child index plus one
			int pos = n48->keys[byte];

			if (pos)
				rc = recursive_iter(n48->children[pos - 1],
				    cb, data);
		}
		break;
	}

	case NODE256: {
		art_node256 *n256 = (art_node256 *)n;

		for (int byte = 0; rc == 0 && byte < 256; byte++) {
			if (n256->children[byte])
				rc = recursive_iter(n256->children[byte],
				    cb, data);
		}
		break;
	}

	default:
		abort();
	}

	return rc;
}

/*
 * Iterates through the entries in the map, invoking a callback for
 * each. The callback receives the key and value of the entry and
 * returns an integer stop value: if it returns non-zero, the iteration
 * stops.
 * @arg t The tree to iterate over
 * @arg cb The callback function to invoke
 * @arg data Opaque handle passed to the callback
 * @return 0 on success, or the return of the callback.
 */
int
art_iter(art_tree *t, art_callback cb, void *data)
{
	art_node *root = t->root;

	return recursive_iter(root, cb, data);
}

/*
 * Recursively iterates over the tree and reports inner nodes as well as
 * leaves. The callback's first argument is always a pointer to a local
 * cb_data record describing the current node:
 *  - for each child of an inner node, the callback is invoked with a
 *    NULL key and value before descending into that child; its return
 *    value is ignored
 *  - for a leaf, the callback is invoked with the leaf's key and value
 *    and its return value is returned
 * The data argument is only forwarded to the recursive calls; it is
 * never handed to the callback.
 * @return 0 if the walk completed, otherwise the non-zero result of a
 * leaf callback.
 */
static int
recursive_iter2(art_node *n, art_callback cb, void *data)
{
	cb_data _cbd, *cbd = &_cbd;
	int first = 1;

	// Handle base cases
	if (!n)
		return 0;
	/*
	 * NOTE(review): n->type is read before the leaf check below, so for
	 * a tagged leaf pointer the read goes through the tagged address
	 * and node is the tagged pointer too. first_child is not assigned
	 * on the leaf path either. Confirm that leaf callbacks do not rely
	 * on node_type or first_child.
	 */
	cbd->node = (void *)n;
	cbd->node_type = n->type;
	cbd->child_idx = -1;

	if (IS_LEAF(n)) {
		art_leaf *l = LEAF_RAW(n);
		return cb(cbd, (const unsigned char *)l->key,
			    l->key_len, l->value, l->val_len);
	}

	int idx, res;
	switch (n->type) {
	case NODE4:
		for (int i = 0; i < n->num_children; i++) {
			// first_child is 1 only for the first reported child
			cbd->first_child = first;
			first = 0;
			// child_idx is the position in the child array
			cbd->child_idx = i;
			cb((void *)cbd, NULL, 0, NULL, 0);
			res = recursive_iter2(((art_node4 *)n)->children[i],
				    cb, data);
			if (res)
				return res;
		}
		break;

	case NODE16:
		for (int i = 0; i < n->num_children; i++) {
			cbd->first_child = first;
			first = 0;
			cbd->child_idx = i;
			cb((void *)cbd, NULL, 0, NULL, 0);
			res = recursive_iter2(((art_node16 *)n)->children[i],
				    cb, data);
			if (res)
				return res;
		}
		break;

	case NODE48:
		for (int i = 0; i < 256; i++) {
			// keys[] stores the child index plus one; 0 = absent
			idx = ((art_node48 *)n)->keys[i];
			if (!idx)
				continue;

			cbd->first_child = first;
			first = 0;
			// child_idx is the key byte here, not the array slot
			cbd->child_idx = i;
			cb((void *)cbd, NULL, 0, NULL, 0);
			res = recursive_iter2(
				    ((art_node48 *)n)->children[idx - 1],
				    cb, data);
			if (res)
				return res;
		}
		break;

	case NODE256:
		for (int i = 0; i < 256; i++) {
			if (!((art_node256 *)n)->children[i])
				continue;
			cbd->first_child = first;
			first = 0;
			// child_idx is the key byte
			cbd->child_idx = i;
			cb((void *)cbd, NULL, 0, NULL, 0);
			res = recursive_iter2(
				    ((art_node256 *)n)->children[i],
				    cb, data);
			if (res)
				return res;
		}
		break;

	default:
		abort();
	}
	return 0;
}

/*
 * Iterates through the whole tree with recursive_iter2, which invokes
 * the callback for the children of inner nodes as well as for leaves
 * and passes it a node description record as first argument.
 * If a leaf callback returns non-zero, then the iteration stops.
 * @arg t The tree to iterate over
 * @arg cb The callback function to invoke
 * @arg data Opaque handle forwarded to recursive_iter2
 * @return 0 on success, or the return of the callback.
 */
int
art_iter2(art_tree *t, art_callback cb, void *data)
{
	art_node *root = t->root;

	return recursive_iter2(root, cb, data);
}

/*
 * Checks whether the key stored in a leaf starts with the given prefix
 * @arg n The leaf
 * @arg prefix The prefix to look for
 * @arg prefix_len The length of the prefix
 * @return 0 if the leaf key starts with the prefix, non-zero otherwise.
 */
static int
leaf_prefix_matches(const art_leaf *n, const unsigned char *prefix,
		int prefix_len)
{
	// A key at least as long as the prefix is compared byte by byte
	if (n->key_len >= (uint32_t)prefix_len)
		return memcmp(n->key, prefix, prefix_len);

	// Shorter keys can never match
	return 1;
}

/*
 * Iterates through the entries in the map, invoking a callback for
 * each entry whose key starts with the given prefix.
 * The callback receives the key and value of the entry and returns an
 * integer stop value: if it returns non-zero, the iteration stops.
 * @arg t The tree to iterate over
 * @arg key The prefix of keys to read
 * @arg key_len The length of the prefix
 * @arg cb The callback function to invoke
 * @arg data Opaque handle passed to the callback
 * @return 0 on success, or the return of the callback.
 */
int
art_iter_prefix(art_tree *t, const unsigned char *key, int key_len,
		art_callback cb, void *data)
{
	art_node **child;
	art_node *n = t->root;
	int prefix_len, depth = 0;
	while (n) {
		// Might be a leaf
		if (IS_LEAF(n)) {
			n = LEAF_RAW(n);
			// Check if the expanded path matches
			if (!leaf_prefix_matches(
				    (art_leaf *)n, key, key_len)) {
				art_leaf *l = (art_leaf *)n;
				return cb(data, (const unsigned char *)l->key,
				    l->key_len, l->value, l->val_len);
			}
			return 0;
		}

		// If the depth matches the prefix, we need to handle this node
		if (depth == key_len) {
			art_leaf *l = minimum(n);
			assert(l != NULL);
			if (!leaf_prefix_matches(l, key, key_len))
				return recursive_iter(n, cb, data);
			return 0;
		}

		// Bail if the prefix does not match
		if (n->partial_len) {
			prefix_len = prefix_mismatch(n, key, key_len, depth);

			/*
			 * For node prefixes longer than MAX_PREFIX_LEN,
			 * prefix_mismatch keeps comparing against a leaf
			 * key and can report more matching bytes than this
			 * node's prefix holds. Clamp the result so that
			 * bytes belonging to deeper nodes are not treated
			 * as matched here, which would iterate the whole
			 * subtree too early.
			 */
			if ((uint32_t)prefix_len > n->partial_len)
				prefix_len = n->partial_len;

			// If there is no match, search is terminated
			if (!prefix_len)
				return 0;

			// If we've matched the prefix, iterate on this node
			else if (depth + prefix_len == key_len) {
				return recursive_iter(n, cb, data);
			}

			/*
			 * The search prefix has bytes left but differs from
			 * the node prefix, so nothing below this node can
			 * match. Returning here also keeps depth from
			 * moving past key_len and indexing key[] out of
			 * bounds below.
			 */
			if ((uint32_t)prefix_len < n->partial_len)
				return 0;

			// if there is a full match, go deeper
			depth = depth + n->partial_len;
		}

		// Descend into the child selected by the next prefix byte
		child = find_child(n, key[depth]);
		n = (child) ? *child : NULL;
		depth++;
	}
	return 0;
}
