/*
 * DSI utilities
 *
 * Copyright (C) 2005-2023 Sebastiano Vigna
 *
 * This program and the accompanying materials are made available under the
 * terms of the GNU Lesser General Public License v2.1 or later,
 * which is available at
 * http://www.gnu.org/licenses/old-licenses/lgpl-2.1-standalone.html,
 * or the Apache Software License 2.0, which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.
 *
 * SPDX-License-Identifier: LGPL-2.1-or-later OR Apache-2.0
 */

package it.unimi.dsi.compression;

import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;

import it.unimi.dsi.bits.BitVector;
import it.unimi.dsi.bits.LongArrayBitVector;
import it.unimi.dsi.fastutil.booleans.BooleanIterator;
import it.unimi.dsi.io.InputBitStream;

/** A decoder that follows 0/1 labelled paths in a tree.
 *
 *  <p>Additionally, the {@link #buildCodes()} method returns a vector
 *  of codewords corresponding to the root-to-leaf paths of an instance of this class. Conversely,
 *  the {@linkplain #TreeDecoder(BitVector[], int[]) codeword-based constructor} builds
 *  a tree out of the codewords generated by root-to-leaf paths.
 */

public final class TreeDecoder implements Decoder, Serializable {
	private static final long serialVersionUID = 2L;
	/** Compile-time switch enabling verbose tracing in {@link #buildTree}. */
	private static final boolean DEBUG = false;

	/** An internal node of the decoding tree. */
	public static class Node implements Serializable {
		private static final long serialVersionUID = 1L;
		/** The children of this node: {@code left} is followed on a 0 bit, {@code right} on a 1 bit. */
		public Node left, right;
	}

	/** A leaf node of the decoding tree. */
	public static class LeafNode extends Node {
		private static final long serialVersionUID = 1L;
		/** The symbol returned when decoding reaches this leaf. */
		public final int symbol;

		/** Creates a leaf node.
		 * @param symbol the symbol for this node.
		 */
		public LeafNode(final int symbol) {
			this.symbol = symbol;
		}
	}

	/** The root of the decoding tree ({@code null} for an empty decoder built from no codewords). */
	private final Node root;
	/** The number of symbols in this decoder. */
	private final int n;

	/** Creates a new decoder using the given tree. It
	 * is responsibility of the caller that the tree is well-formed,
	 * that is, that all internal nodes are instances of {@link TreeDecoder.Node}
	 * and all leaf nodes are instances of  {@link TreeDecoder.LeafNode}.
	 *
	 * @param root the root of a decoding tree.
	 * @param n the number of leaves (symbols).
	 */
	public TreeDecoder(final Node root, final int n) {
		this.root = root;
		this.n = n;
	}

	/** Creates a new codeword-based decoder starting from a set
	 * of complete, lexicographically ordered codewords. It
	 * is responsibility of the caller that the tree is well-formed,
	 * that is, that the provided codewords are exactly the root-to-leaf
	 * paths of such a tree.
	 *
	 * <p>An empty array of codewords yields an empty decoder (with a {@code null} root).
	 *
	 * @param lexSortedCodeWord a vector of lexicographically sorted codewords.
	 * @param symbol a mapping from codewords to symbols: {@code symbol[i]} is the symbol
	 * associated with {@code lexSortedCodeWord[i]}.
	 * @throws IllegalStateException if a group of two or more codewords cannot be split on
	 * the next bit (e.g., because the codewords are not the root-to-leaf paths of a complete tree).
	 */
	public TreeDecoder(final BitVector[] lexSortedCodeWord, final int[] symbol) {
		this(buildTree(lexSortedCodeWord, symbol, 0, 0, lexSortedCodeWord.length), lexSortedCodeWord.length);
	}

	/** Recursively builds the subtree whose leaves are the codewords
	 * {@code lexSortedCodeWords[offset..offset+length)}, splitting them on bit {@code prefix}.
	 *
	 * @param lexSortedCodeWords lexicographically sorted codewords.
	 * @param symbol the symbols associated with the codewords (parallel to {@code lexSortedCodeWords}).
	 * @param prefix the index of the bit on which the codewords in range are split.
	 * @param offset the index of the first codeword in range.
	 * @param length the number of codewords in range.
	 * @return the root of the subtree, or {@code null} if {@code length} is zero.
	 * @throws IllegalStateException if {@code length} is at least two and all codewords in range
	 * have the same bit at position {@code prefix}.
	 */
	private static Node buildTree(final BitVector[] lexSortedCodeWords, final int[] symbol, final int prefix, final int offset, final int length) {
		if (DEBUG) {
			System.err.println("****** " + offset + " " + length);
			System.err.println(Arrays.toString(lexSortedCodeWords));
			for(int i = 0; i < length; i++) {
				System.err.print(lexSortedCodeWords[offset + i].length() + "\t");
				for(int j = 0; j < lexSortedCodeWords[offset + i].length(); j++) System.err.print(lexSortedCodeWords[offset + i].getBoolean(j) ? 1 : 0);
				System.err.println();
			}
		}

		// No codewords at all: empty tree. Recursive calls always pass length >= 1,
		// so this can only happen at the top level.
		if (length == 0) return null;
		if (length == 1) return new LeafNode(symbol[offset]);

		// Look for the position where bit `prefix` changes between consecutive codewords
		// (unique when the codewords are sorted); scan from the end of the range.
		for (int i = length - 2; i >= 0; i--) {
			if (lexSortedCodeWords[offset + i].getBoolean(prefix) != lexSortedCodeWords[offset + i + 1].getBoolean(prefix)) {
				final Node node = new Node();
				node.left = buildTree(lexSortedCodeWords, symbol, prefix + 1, offset, i + 1);
				node.right = buildTree(lexSortedCodeWords, symbol, prefix + 1, offset + i + 1, length - i - 1);
				return node;
			}
		}

		throw new IllegalStateException("The " + length + " codewords starting at index " + offset
				+ " all have the same bit at position " + prefix + ": the codewords do not describe a complete tree");
	}

	/** Decodes a symbol by following the tree from the root: {@code false} goes left, {@code true} goes right.
	 *
	 * @param iterator the source of bits.
	 * @return the symbol of the leaf reached.
	 */
	@Override
	public int decode(final BooleanIterator iterator) {
		Node node = root;
		while(! (node instanceof LeafNode))
			node = iterator.nextBoolean() ? node.right : node.left;
		return ((LeafNode)node).symbol;
	}

	/** Decodes a symbol by following the tree from the root: a 0 bit goes left, a 1 bit goes right.
	 *
	 * @param ibs the source of bits.
	 * @return the symbol of the leaf reached.
	 * @throws IOException if reading from {@code ibs} fails.
	 */
	@Override
	public int decode(final InputBitStream ibs) throws IOException {
		Node node = root;
		while(! (node instanceof LeafNode))
			node = ibs.readBit() == 0 ? node.left : node.right;
		return ((LeafNode)node).symbol;
	}

	/** Populates the codeword vector by scanning recursively
	 * the decoding tree.
	 *
	 * @param codeWord the vector of codewords, indexed by symbol, to be filled.
	 * @param node a subtree of the decoding tree.
	 * @param prefix the path leading to {@code node}; it is stored (not copied) when {@code node} is a leaf.
	 */
	private void buildCodes(final BitVector[] codeWord, final TreeDecoder.Node node, final BitVector prefix) {

		if (node instanceof TreeDecoder.LeafNode) {
			codeWord[((TreeDecoder.LeafNode)node).symbol] = prefix;
			return;
		}

		// Left child: extend a copy of the prefix by one bit, which is expected to be 0
		// (decode() follows left on 0).
		BitVector bitVector = prefix.copy();
		bitVector.length(bitVector.length() + 1);
		buildCodes(codeWord, node.left, bitVector);

		// Right child: extend a fresh copy by one bit and set it to 1.
		bitVector = prefix.copy();
		bitVector.length(bitVector.length() + 1);
		bitVector.set(bitVector.length() - 1);

		buildCodes(codeWord, node.right, bitVector);
	}

	/** Generate the codewords corresponding to this tree decoder.
	 *
	 * @return a vector of codewords for this decoder, indexed by symbol; if the
	 * tree is empty, all entries are {@code null}.
	 */
	public BitVector[] buildCodes() {
		final BitVector[] codeWord = new BitVector[n];
		// An empty tree has no paths (consistent with succinctRepresentation()).
		if (root != null) buildCodes(codeWord, root, LongArrayBitVector.getInstance());
		return codeWord;
	}

	/** Appends to {@code bitVector} the parentheses describing the subtree rooted at {@code node}.
	 *
	 * <p>Each internal node emits one matched pair ({@code true} = open, {@code false} = close)
	 * enclosing the representation of its left subtree; the right subtree is emitted right after
	 * the pair, at the same nesting level. Leaves emit nothing.
	 *
	 * @param node the root of the subtree to visit (must not be {@code null}).
	 * @param bitVector the vector to which parentheses are appended.
	 */
	private static void visit(Node node, final LongArrayBitVector bitVector) {
		if (node instanceof LeafNode) return;

		do {
			bitVector.add(true);
			visit(node.left, bitVector);
			bitVector.add(false);
		} while(! ((node = node.right) instanceof LeafNode));
	}

	/** Returns a balanced-parentheses representation of the shape of the decoding tree.
	 *
	 * <p>Each internal node contributes a matched pair of bits ({@code true} for an open
	 * parenthesis, {@code false} for a closed one) enclosing the representation of its left
	 * subtree, with its right subtree following the pair; the whole sequence is enclosed
	 * in an additional fake pair.
	 *
	 * @return a trimmed bit vector of balanced parentheses describing the tree; it contains just
	 * the fake pair if the tree is empty or is a single leaf.
	 */
	public LongArrayBitVector succinctRepresentation() {
		final LongArrayBitVector bitVector = LongArrayBitVector.getInstance();

		bitVector.add(true); // Fake open parenthesis
		if (root != null) visit(root, bitVector);
		bitVector.add(false);  // Fake closed parenthesis
		bitVector.trim();

		return bitVector;
	}
}
