/*
 * DSI utilities
 *
 * Copyright (C) 2005-2023 Sebastiano Vigna
 *
 * This program and the accompanying materials are made available under the
 * terms of the GNU Lesser General Public License v2.1 or later,
 * which is available at
 * http://www.gnu.org/licenses/old-licenses/lgpl-2.1-standalone.html,
 * or the Apache Software License 2.0, which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.
 *
 * SPDX-License-Identifier: LGPL-2.1-or-later OR Apache-2.0
 */

package it.unimi.dsi.util;

import static it.unimi.dsi.bits.LongArrayBitVector.word;

import java.io.Serializable;
import java.util.Iterator;
import java.util.ListIterator;

import it.unimi.dsi.bits.BitVector;
import it.unimi.dsi.bits.LongArrayBitVector;
import it.unimi.dsi.bits.TransformationStrategy;
import it.unimi.dsi.fastutil.booleans.BooleanIterator;
import it.unimi.dsi.fastutil.objects.AbstractObject2LongFunction;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import it.unimi.dsi.fastutil.objects.ObjectList;
import it.unimi.dsi.lang.MutableString;


/** An immutable implementation of binary tries.
 *
 * <P>Instance of this class are built starting from a lexicographically ordered
 * list of {@link BitVector}s representing binary words. Each word
 * is assigned its position (starting from 0) in the list. The words are then organised in a
 * binary trie with path compression.
 *
 * <p>Once the trie has been
 * built, it is possible to ask whether a word <var>w</var> is {@linkplain #get(BooleanIterator) contained in the trie}
 * (getting back its position in the list), the {@linkplain #getInterval(BooleanIterator) interval given by the words extending <var>w</var>} and the
 * {@linkplain #getApproximatedInterval(BooleanIterator) approximated interval defined by <var>w</var>}.

 * @author Sebastiano Vigna
 * @since 0.9.2
 */

public class ImmutableBinaryTrie<T> extends AbstractObject2LongFunction<T> implements Serializable {
	private static final long serialVersionUID = 1L;

	/** A node in the trie. */
	protected static class Node implements Serializable {
		private static final long serialVersionUID = 1L;
		public Node left, right;
		/** An array containing the path compacted in this node ({@code null} if there is no compaction at this node). */
		public final long[] path;
		/** The length of the path compacted in this node (0 if there is no compaction at this node). */
		public final int pathLength;
		/** If nonnegative, this node represent the <code>word</code>-th word. */
		public final int word ;

		/** Creates a node representing a word.
		 *
		 * <p>Note that the long array contained in <code>path</code> will be stored inside the node.
		 *
		 * @param path the path compacted in this node, or {@code null} for the empty path.
		 * @param word the index of the word represented by this node.
		 */

		public Node(final BitVector path, final int word) {
			if (path == null) {
				this.path = null;
				this.pathLength = 0;
			}
			else {
				this.path = path.bits();
				this.pathLength = (int) path.length();
			}
			this.word = word;
		}

		/** Creates a node that does not represent a word.
		 *
		 * @param path the path compacted in this node, or {@code null} for the empty path.
		 */
		public Node(final BitVector path) {
			this(path, -1);
		}


		/** Returns true if this node is a leaf.
		 *
		 * @return  true if this node is a leaf.
		 */
		public boolean isLeaf() {
			return right == null && left == null;
		}

		@Override
		public String toString() {
			return "[" + path + ", " + word + "]";
		}

	}

	/** The root of the trie. */
	protected final Node root;
	/** The number of words in this trie. */
	private int size;
	private final TransformationStrategy<? super T> transformationStrategy;

	private static boolean get(final long[] a, final long i) {
		return (a[word(i)] & (1L << i)) != 0;
	}


	/** Creates a trie from a set of elements.
	 *
	 * @param elements a set of elements
	 * @param transformationStrategy a transformation strategy that must turn <code>elements</code> into a list of
	 * distinct, lexicographically increasing (in iteration order) binary words.
	 */

	public ImmutableBinaryTrie(final Iterable<? extends T> elements, final TransformationStrategy<? super T> transformationStrategy) {
		this.transformationStrategy = transformationStrategy;
		defRetValue = -1;
		// Check order
		final Iterator<? extends T> iterator = elements.iterator();
		final ObjectList<LongArrayBitVector> words = new ObjectArrayList<>();
		int cmp;
		if (iterator.hasNext()) {
			final LongArrayBitVector prev = LongArrayBitVector.copy(transformationStrategy.toBitVector(iterator.next()));
			words.add(prev.copy());
			BitVector curr;

			while(iterator.hasNext()) {
				curr = transformationStrategy.toBitVector(iterator.next());
				cmp = prev.compareTo(curr);
				if (cmp == 0) throw new IllegalArgumentException("The trie elements are not unique");
				if (cmp > 0) throw new IllegalArgumentException("The trie elements are not sorted");
				prev.replace(curr);
				words.add(prev.copy());
			}
		}
		root = buildTrie(words, 0);
	}

	/** Builds a trie recursively.
	 *
	 * <p>The trie will contain the suffixes of words in <code>words</code> starting at <code>pos</code>.
	 *
	 * @param elements a list of elements.
	 * @param pos a starting position.
	 * @return a trie containing the suffixes of words in <code>words</code> starting at <code>pos</code>.
	 */

	protected Node buildTrie(final ObjectList<LongArrayBitVector> elements, final int pos) {
		// TODO: on-the-fly check for lexicographical order

		if (elements.size() == 0) return null;

		final BitVector first = elements.get(0);
		BitVector curr;
		int prefix = (int) first.length(), change = -1, j;

		// We rule out the case of a single word (it would work anyway, but it's faster)
		if (elements.size() == 1) return new Node(pos < prefix ? LongArrayBitVector.copy(first.subVector(pos, prefix)) : null, size++);

		// 	Find maximum common prefix. change records the point of change (for splitting the word set).
		for(final ListIterator<LongArrayBitVector> i = elements.listIterator(1); i.hasNext();) {
			curr = i.next();

			if (curr.length() < prefix) prefix = (int) curr.length();
			for(j = pos; j < prefix; j++) if (first.getBoolean(j) != curr.getBoolean(j)) break;
			if (j < prefix) {
				change = i.previousIndex();
				prefix = j;
			}
		}

		final Node n;
		if (prefix == first.length()) {
			// Special case: the first word is the common prefix. We must store it in the node,
			// and explicitly search for the actual change point, which is the first
			// word with prefix-th bit true.
			change = 1;
			for(final ListIterator<LongArrayBitVector> i = elements.listIterator(1); i.hasNext();) {
				curr = i.next();
				if (curr.getBoolean(prefix)) break;
				change++;
			}

			n = new Node(prefix > pos ? LongArrayBitVector.copy(first.subVector(pos, prefix)) : null, size++);
			n.left = buildTrie(elements.subList(1, change), prefix + 1);
			n.right = buildTrie(elements.subList(change, elements.size()), prefix + 1);
		}
		else {
			n = new Node(prefix > pos ? LongArrayBitVector.copy(first.subVector(pos, prefix)) : null); // There's some common prefix
			n.left = buildTrie(elements.subList(0, change), prefix + 1);
			n.right = buildTrie(elements.subList(change, elements.size()), prefix + 1);
		}
		return n;
	}

	/** Returns the number of binary words in this trie.
	 *
	 * @return the number of binary words in this trie.
	 */

	@Override
	public int size() {
		return size;
	}

	@SuppressWarnings("unchecked")
	public long getIndex(final Object element) {
		final BitVector word = transformationStrategy.toBitVector((T)element);
		final int length = (int) word.length();
		Node n = root;

		int pos = 0; // Current position in word
		long[] path;

		while(n != null) {
			if (pos == length) return n.word;

			path = n.path;
			if (path != null) {
				final int minLength = Math.min(length - pos, n.pathLength);
				int i;
				for(i = 0; i < minLength; i++) if (word.getBoolean(pos + i) != get(path, i)) break;
				// Incompatible with current path.
				if (i < minLength) return -1;

				pos += i;

				// Completely contained in the current path (note that n.word == -1 if this is not a word).
				if (pos == length) return n.word;
			}

			n = word.getBoolean(pos++) ? n.right : n.left;
		}

		return -1;
	}

	@Override
	public long getLong(final Object element) {
		final long result = getIndex(element);
		return result == -1 ? defRetValue : result;
	}

	@Override
	public boolean containsKey(final Object element) {
		return getIndex(element) != -1;
	}


	/** Return the index of the word returned by the given iterator, or -1 if the word is not in this trie.
	 *
	 * @param iterator a boolean iterator that will be used to find a word in this trie.
	 * @return the index of the specified word, or -1 if the word returned by the iterator is not in this trie.
	 * @see #getLong(Object)
	 */

	public int get(final BooleanIterator iterator) {
		Node n = root;
		int pathLength;
		long[] path;

		while(n != null) {
			if (! iterator.hasNext()) return n.word;

			pathLength = n.pathLength;

			if (pathLength != 0) {
				int i;
				path = n.path;
				for(i = 0; i < pathLength && iterator.hasNext(); i++) if (iterator.nextBoolean() != get(path, i)) break;
				// Incompatible with current path.
				if (i < pathLength) return -1;

				// Completely contained in the current path (note that n.word == -1 if this is not a word).
				if (! iterator.hasNext()) return n.word;
			}

			n = iterator.nextBoolean() ? n.right : n.left;

		}

		return -1;
	}

	/** Returns an interval given by the smallest and the largest word in the trie starting with the specified word.
	 *
	 * @param word a word.
	 * @return  an interval given by the smallest and the largest word in the trie
	 * that start with <code>word</code> (thus, the {@linkplain Intervals#EMPTY_INTERVAL empty inteval}
	 * if no such words exist).
	 * @see #getInterval(BooleanIterator)
	 */

	public Interval getInterval(final BitVector word) {
		final int length = (int) word.length();
		Node n = root;
		long[] path;

		int pos = 0; // Current position in word

		while(n != null) {
			// We found the current path: we go searching for left and right delimiters.
			if (pos == length) break;

			path = n.path;

			if (path != null) {
				final int maxLength = Math.min(length - pos, n.pathLength);
				int i;
				for(i = 0; i < maxLength; i++) if (word.getBoolean(pos + i) != get(path, i)) break;
				// Incompatible with current path--we return the empty interval.
				if (i < maxLength) return Intervals.EMPTY_INTERVAL;

				pos += i;

				// Completely contained in the current path: we go searching for left and right delimiters.
				if (pos == length) break;
			}

			n = word.getBoolean(pos++) ? n.right : n.left;
		}

		// If n == null, we did not found the path. Otherwise, it's the current node,
		// and we must search for left and right delimiters.
		if (n == null) return Intervals.EMPTY_INTERVAL;

		Node l = n;
		// Searching for the left extreme...
		while(l.word < 0) l = l.left != null ? l.left : l.right;
		// Searching for the right extreme, unless we're on a leaf.
		while(! n.isLeaf()) n = n.right != null ? n.right : n.left;

		return Interval.valueOf(l.word, n.word);

	}


	/** Returns an interval given by the smallest and the largest word in the trie starting with
	 * the word returned by the given iterator.
	 *
	 * @param iterator an iterator.
	 * @return  an interval given by the smallest and the largest word in the trie
	 * that start with the word returned by <code>iterator</code> (thus, the {@linkplain Intervals#EMPTY_INTERVAL empty inteval}
	 * if no such words exist).
	 * @see #getInterval(BitVector)
	 */

	public Interval getInterval(final BooleanIterator iterator) {
		Node n = root;
		boolean mismatch = false;
		long[] path;
		int pathLength;

		while(n != null) {
			// We found the current path: we go searching for left and right delimiters.
			if (! iterator.hasNext()) break;

			pathLength = n.pathLength;
			if (pathLength != 0) {
				int i;
				path = n.path;
				for(i = 0; i < pathLength && iterator.hasNext(); i++) if ((mismatch = (iterator.nextBoolean() != get(path, i)))) break;
				// Incompatible with current path--we return the empty interval.
				if (mismatch) return Intervals.EMPTY_INTERVAL;

				// Completely contained in the current path: we go searching for left and right delimiters.
				if (! iterator.hasNext()) break;

			}

			n = iterator.nextBoolean() ? n.right : n.left;
		}

		// If n == null, we did not found the path. Otherwise, it's the current node,
		// and we must search for left and right delimiters.
		if (n == null) return Intervals.EMPTY_INTERVAL;

		Node l = n;
		// Searching for the left extreme...
		while(l.word < 0) l = l.left != null ? l.left : l.right;
		// Searching for the right extreme, unless we're on a leaf.
		while (! n.isLeaf()) n = n.right != null ? n.right : n.left;

		return Interval.valueOf(l.word, n.word);

	}


	/** Returns an approximated interval around the specified word.
	 *
	 * <P>Given a word <var>w</var>, the corresponding approximated  interval is
	 * defined as follows: if the words in the approximator are thought of as left interval extremes in a
	 * larger lexicographically ordered set of words, and we number these word intervals using the
	 * indices of their left extremes, then the first word extending <var>w</var> would be in the
	 * word interval given by the left extreme of the interval returned by this method, whereas
	 * the last word extending <var>w</var> would be in the word interval given by the right
	 * extreme of the interval returned by this method. If no word in the larger set could possibly extend
	 * <var>w</var> (because <var>w</var> is smaller than the lexicographically smallest word in the approximator)
	 * the result is just an {@linkplain it.unimi.dsi.util.Intervals#EMPTY_INTERVAL empty interval}.
	 *
	 * @param element an element.
	 * @return an approximated interval around the specified word.
	 * @see #getApproximatedInterval(BooleanIterator)
	 */

	public Interval getApproximatedInterval(final T element) {
		final BitVector word = transformationStrategy.toBitVector(element);
		final int length = (int) word.length();
		Node n = root;
		long[] path;
		boolean exactMatch = false, mismatch = false, nextBit;

		int pos = 0; // Current position in word

		while(n != null) {
			// We found the current path: we go searching for left and right delimiters.

			path = n.path;

			if (pos == length) {
				if (n.word >= 0 && path == null) exactMatch = true;
				break;
			}

			if (path != null) {
				final int maxLength = Math.min(length - pos, n.pathLength);
				int i;
				for(i = 0; i < maxLength; i++) if (mismatch = (word.getBoolean(pos + i) != get(path, i))) break;

				if (mismatch) {
					// System.err.println("Exit 1");
					// A mismatch. In this case, it is guaranteed that all
					// strings starting with the prefix examined so far lie
					// in a single block. The block index depends, however
					// on the bit that went wrong.
					if (get(path, i)) {
						while(n.word < 0) n = n.left != null ? n.left : n.right;
						return n.word > 0 ? Interval.valueOf(n.word - 1) : Intervals.EMPTY_INTERVAL;
					}
					else {
						while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
						return Interval.valueOf(n.word);
					}
				}

				pos += i;

				// Completely contained in the current path
				if (pos == length) {
					if (i == n.pathLength && n.word >= 0) exactMatch = true;
					break;
				}

			}

			if (n.isLeaf()) break;

			nextBit = word.getBoolean(pos++);

			// We would like to take an impossible turn. This case is similar to
			// prefix mismatches, with subtly different off-by-ones.
			if (nextBit && n.right == null) {
				while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
				return Interval.valueOf(n.word);
			}
			else if (! nextBit && n.left == null) {
				while(n.word < 0) n = n.left != null ? n.left : n.right;
				return Interval.valueOf(n.word);
			}

			n = nextBit ? n.right : n.left;
		}

		Node l = n;
		// Searching for the left extreme...

		//System.err.println("Going for exit 2: l:" + l + " n:" + n);

		while(l.word < 0) l = l.left != null ? l.left : l.right;

		// Searching for the right extreme, unless we're on a leaf.
		while (! n.isLeaf()) n = n.right != null ? n.right : n.left;

		// System.err.println("Following exit 2: l:" + l + " n:" + n);

		// If did not find an exact match and l.word is 0 we are lexicographically before every word.
		if (pos == length && ! exactMatch) {
			if (l.word == 0) return mismatch ? Intervals.EMPTY_INTERVAL : Interval.valueOf(l.word, n.word);
			else return Interval.valueOf(l.word - 1, n.word);
		}

		// System.err.println("Exit 2 (exactMatch: " + exactMatch +")");
		return Interval.valueOf(l.word, n.word);
	}



	/** Returns an approximated prefix interval around the word returned by the specified iterator.
	 *
	 * @param iterator an iterator.
	 * @return an approximated interval around the specified word: if the words in this trie
	 * are thought of as left interval extremes in a larger lexicographically ordered set of words,
	 * and we number these word intervals using the indices of their left extremes,
	 * then the first word extending <code>word</code> would be in the word interval given by
	 * the left extreme of the {@link Interval} returned by this method, whereas
	 * the last word extending <code>word</code> would be in the word
	 * interval given by the right extreme of the {@link Interval} returned by this method.
	 * @see #getApproximatedInterval(Object)
	 */

	public Interval getApproximatedInterval(final BooleanIterator iterator) {
		Node n = root;
		long[] path;
		boolean exactMatch = false, mismatch = false, nextBit;

		for(;;) {
			// We found the current path: we go searching for left and right delimiters.

			path = n.path;

			if (! iterator.hasNext()) {
				if (n.word >= 0 && path == null) exactMatch = true;
				break;
			}

			if (path != null) {
				int i;
				final int pathSize = n.pathLength;
				for(i = 0; i < pathSize && iterator.hasNext(); i++) if ((mismatch = (iterator.nextBoolean() != get(path, i)))) break;

				if (mismatch) {
					// System.err.println("Exit 1");
					// A mismatch. In this case, it is guaranteed that all
					// strings starting with the prefix examined so far lie
					// in a single block. The block index depends, however
					// on the bit that went wrong.
					if (get(path, i)) {
						while(n.word < 0) n = n.left != null ? n.left : n.right;
						return n.word > 0 ? Interval.valueOf(n.word - 1) : Intervals.EMPTY_INTERVAL;
					}
					else {
						while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
						return Interval.valueOf(n.word);
					}
				}

				// Completely contained in the current path
				if (! iterator.hasNext()) {
					if (i == pathSize && n.word >= 0) exactMatch = true;
					break;
				}

			}

			if (n.isLeaf()) break;

			nextBit = iterator.nextBoolean();

			// We would like to take an impossible turn. This case is similar to
			// prefix mismatches, with subtly different off-by-ones.
			if (nextBit && n.right == null) {
				while(! n.isLeaf()) n = n.right != null ? n.right : n.left;
				return Interval.valueOf(n.word);
			}
			else if (! nextBit && n.left == null) {
				while(n.word < 0) n = n.left != null ? n.left : n.right;
				return Interval.valueOf(n.word);
			}

			n = nextBit ? n.right : n.left;
		}

		Node l = n;
		// Searching for the left extreme...

		//System.err.println("Going for exit 2: l:" + l + " n:" + n);

		while(l.word < 0) l = l.left != null ? l.left : l.right;

		// Searching for the right extreme, unless we're on a leaf.
		while (! n.isLeaf()) n = n.right != null ? n.right : n.left;

		// If did not find an exact match and l.word is 0 we are lexicographically before every word.
		if (! iterator.hasNext() && ! exactMatch) {
			if (l.word == 0) return mismatch ? Intervals.EMPTY_INTERVAL : Interval.valueOf(0);
			else return Interval.valueOf(l.word - 1, n.word);
		}

		// System.err.println("Exit 2 (hasNext: " +iterator.hasNext() + " exactMatch: " + exactMatch +")");
		return Interval.valueOf(l.word, n.word);
	}



	private void recToString(final Node n, final MutableString printPrefix, final MutableString result, final MutableString path, final int level) {
		if (n == null) return;

		//System.err.println("Called with prefix " + printPrefix);

		result.append(printPrefix).append('(').append(level).append(')');

		if (n.path != null) {
			path.append(LongArrayBitVector.wrap(n.path, n.pathLength));
			result.append(" path:").append(LongArrayBitVector.wrap(n.path, n.pathLength));
		}
		if (n.word >= 0) result.append(" word: ").append(n.word).append(" (").append(path).append(')');

		result.append('\n');

		path.append('0');
		recToString(n.left, printPrefix.append('\t').append("0 => "), result, path, level + 1);
		path.charAt(path.length() - 1, '1');
		recToString(n.right, printPrefix.replace(printPrefix.length() - 5, printPrefix.length(), "1 => "), result, path, level + 1);
		path.delete(path.length() - 1, path.length());
		printPrefix.delete(printPrefix.length() - 6, printPrefix.length());

		//System.err.println("Path now: " + path + " Going to delete from " + (path.length() - n.pathLength));

		path.delete(path.length() - n.pathLength, path.length());
	}

	@Override
	public String toString() {
		final MutableString s = new MutableString();
		recToString(root, new MutableString(), s, new MutableString(), 0);
		return s.toString();
	}
}
