/*
 * Copyright (c) 2018, Google LLC. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

/**
 * @test 8200301 8201194
 * @summary deduplicate lambda methods with the same body, target type, and captured state
 * @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api
 *     jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp
 *     jdk.compiler/com.sun.tools.javac.file jdk.compiler/com.sun.tools.javac.main
 *     jdk.compiler/com.sun.tools.javac.tree jdk.compiler/com.sun.tools.javac.util
 * @run main DeduplicationTest
 */
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;

import com.sun.source.util.JavacTask;
import com.sun.source.util.TaskEvent;
import com.sun.source.util.TaskEvent.Kind;
import com.sun.source.util.TaskListener;
import com.sun.tools.classfile.Attribute;
import com.sun.tools.classfile.BootstrapMethods_attribute;
import com.sun.tools.classfile.BootstrapMethods_attribute.BootstrapMethodSpecifier;
import com.sun.tools.classfile.ClassFile;
import com.sun.tools.classfile.ConstantPool.CONSTANT_MethodHandle_info;
import com.sun.tools.javac.api.ClientCodeWrapper.Trusted;
import com.sun.tools.javac.api.JavacTool;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.comp.TreeDiffer;
import com.sun.tools.javac.comp.TreeHasher;
import com.sun.tools.javac.file.JavacFileManager;
import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
import com.sun.tools.javac.tree.JCTree.JCExpression;
import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.JCTree.JCLambda;
import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
import com.sun.tools.javac.tree.JCTree.JCTypeCast;
import com.sun.tools.javac.tree.JCTree.Tag;
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context;
import com.sun.tools.javac.util.JCDiagnostic;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import javax.tools.Diagnostic;
import javax.tools.DiagnosticListener;
import javax.tools.JavaFileObject;

public class DeduplicationTest {

    public static void main(String[] args) throws Exception {
        JavacFileManager fileManager = new JavacFileManager(new Context(), false, UTF_8);
        JavacTool javacTool = JavacTool.create();
        Listener diagnosticListener = new Listener();
        Path testSrc = Paths.get(System.getProperty("test.src"));
        Path file = testSrc.resolve("Deduplication.java");
        JavacTask task =
                javacTool.getTask(
                        null,
                        null,
                        diagnosticListener,
                        Arrays.asList(
                                "-d",
                                ".",
                                "-XDdebug.dumpLambdaToMethodDeduplication",
                                "-XDdebug.dumpLambdaToMethodStats"),
                        null,
                        fileManager.getJavaFileObjects(file));
        Map<JCLambda, JCLambda> dedupedLambdas = new LinkedHashMap<>();
        task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas));
        Iterable<? extends JavaFileObject> generated = task.generate();
        if (!diagnosticListener.unexpected.isEmpty()) {
            throw new AssertionError(
                    diagnosticListener
                            .unexpected
                            .stream()
                            .map(
                                    d ->
                                            String.format(
                                                    "%s: %s",
                                                    d.getCode(), d.getMessage(Locale.getDefault())))
                            .collect(joining(", ", "unexpected diagnostics: ", "")));
        }

        // Assert that each group of lambdas was deduplicated.
        Map<JCLambda, JCLambda> actual = diagnosticListener.deduplicationTargets();
        dedupedLambdas.forEach(
                (k, v) -> {
                    if (!actual.containsKey(k)) {
                        throw new AssertionError("expected " + k + " to be deduplicated");
                    }
                    if (!v.equals(actual.get(k))) {
                        throw new AssertionError(
                                String.format(
                                        "expected %s to be deduplicated to:\n  %s\nwas:  %s",
                                        k, v, actual.get(v)));
                    }
                });

        // Assert that the output contains only the canonical lambdas, and not the deduplicated
        // lambdas.
        Set<String> bootstrapMethodNames = new TreeSet<>();
        for (JavaFileObject output : generated) {
            ClassFile cf = ClassFile.read(output.openInputStream());
            BootstrapMethods_attribute bsm =
                    (BootstrapMethods_attribute) cf.getAttribute(Attribute.BootstrapMethods);
            for (BootstrapMethodSpecifier b : bsm.bootstrap_method_specifiers) {
                bootstrapMethodNames.add(
                        ((CONSTANT_MethodHandle_info)
                                        cf.constant_pool.get(b.bootstrap_arguments[1]))
                                .getCPRefInfo()
                                .getNameAndTypeInfo()
                                .getName());
            }
        }
        Set<String> deduplicatedNames =
                diagnosticListener
                        .expectedLambdaMethods()
                        .stream()
                        .map(s -> s.getSimpleName().toString())
                        .sorted()
                        .collect(toSet());
        if (!deduplicatedNames.equals(bootstrapMethodNames)) {
            throw new AssertionError(
                    String.format(
                            "expected deduplicated methods: %s, but saw: %s",
                            deduplicatedNames, bootstrapMethodNames));
        }
    }

    /** Returns the parameter symbols of the given lambda. */
    private static List<Symbol> paramSymbols(JCLambda lambda) {
        return lambda.params.stream().map(x -> x.sym).collect(toList());
    }

    /** A diagnostic listener that records debug messages related to lambda desugaring. */
    @Trusted
    static class Listener implements DiagnosticListener<JavaFileObject> {

        /** A map from method symbols to lambda trees for desugared lambdas. */
        final Map<MethodSymbol, JCLambda> lambdaMethodSymbolsToTrees = new LinkedHashMap<>();

        /**
         * A map from lambda trees that were deduplicated to the method symbol of the canonical
         * lambda implementation method they were deduplicated to.
         */
        final Map<JCLambda, MethodSymbol> deduped = new LinkedHashMap<>();

        final List<Diagnostic<? extends JavaFileObject>> unexpected = new ArrayList<>();

        @Override
        public void report(Diagnostic<? extends JavaFileObject> diagnostic) {
            JCDiagnostic d = (JCDiagnostic) diagnostic;
            switch (d.getCode()) {
                case "compiler.note.lambda.stat":
                    lambdaMethodSymbolsToTrees.put(
                            (MethodSymbol) d.getArgs()[1],
                            (JCLambda) d.getDiagnosticPosition().getTree());
                    break;
                case "compiler.note.verbose.l2m.deduplicate":
                    deduped.put(
                            (JCLambda) d.getDiagnosticPosition().getTree(),
                            (MethodSymbol) d.getArgs()[0]);
                    break;
                default:
                    unexpected.add(diagnostic);
            }
        }

        /** Returns expected lambda implementation method symbols. */
        Set<MethodSymbol> expectedLambdaMethods() {
            return lambdaMethodSymbolsToTrees
                    .entrySet()
                    .stream()
                    .filter(e -> !deduped.containsKey(e.getValue()))
                    .map(Map.Entry::getKey)
                    .collect(toSet());
        }

        /**
         * Returns a mapping from deduplicated lambda trees to the tree of the canonical lambda they
         * were deduplicated to.
         */
        Map<JCLambda, JCLambda> deduplicationTargets() {
            return deduped.entrySet()
                    .stream()
                    .collect(
                            toMap(
                                    Map.Entry::getKey,
                                    e -> lambdaMethodSymbolsToTrees.get(e.getValue()),
                                    (a, b) -> {
                                        throw new AssertionError();
                                    },
                                    LinkedHashMap::new));
        }
    }

    /**
     * A task listener that tests {@link TreeDiffer} and {@link TreeHasher} on all lambda trees in a
     * compilation, post-analysis.
     */
    private static class TreeDiffHashTaskListener implements TaskListener {

        /**
         * A map from deduplicated lambdas to the canonical lambda they are expected to be
         * deduplicated to.
         */
        private final Map<JCLambda, JCLambda> dedupedLambdas;

        public TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas) {
            this.dedupedLambdas = dedupedLambdas;
        }

        @Override
        public void finished(TaskEvent e) {
            if (e.getKind() != Kind.ANALYZE) {
                return;
            }
            // Scan the compilation for calls to a varargs method named 'group', whose arguments
            // are a group of lambdas that are equivalent to each other, but distinct from all
            // lambdas in the compilation unit outside of that group.
            List<List<JCLambda>> lambdaGroups = new ArrayList<>();
            new TreeScanner() {
                @Override
                public void visitApply(JCMethodInvocation tree) {
                    if (tree.getMethodSelect().getTag() == Tag.IDENT
                            && ((JCIdent) tree.getMethodSelect())
                                    .getName()
                                    .contentEquals("group")) {
                        List<JCLambda> xs = new ArrayList<>();
                        for (JCExpression arg : tree.getArguments()) {
                            if (arg instanceof JCTypeCast) {
                                arg = ((JCTypeCast) arg).getExpression();
                            }
                            xs.add((JCLambda) arg);
                        }
                        lambdaGroups.add(xs);
                    }
                    super.visitApply(tree);
                }
            }.scan((JCCompilationUnit) e.getCompilationUnit());
            for (int i = 0; i < lambdaGroups.size(); i++) {
                List<JCLambda> curr = lambdaGroups.get(i);
                JCLambda first = null;
                // Assert that all pairwise combinations of lambdas in the group are equal, and
                // hash to the same value.
                for (JCLambda lhs : curr) {
                    if (first == null) {
                        first = lhs;
                    } else {
                        dedupedLambdas.put(lhs, first);
                    }
                    for (JCLambda rhs : curr) {
                        if (!new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
                                .scan(lhs.body, rhs.body)) {
                            throw new AssertionError(
                                    String.format(
                                            "expected lambdas to be equal\n%s\n%s", lhs, rhs));
                        }
                        if (TreeHasher.hash(lhs, paramSymbols(lhs))
                                != TreeHasher.hash(rhs, paramSymbols(rhs))) {
                            throw new AssertionError(
                                    String.format(
                                            "expected lambdas to hash to the same value\n%s\n%s",
                                            lhs, rhs));
                        }
                    }
                }
                // Assert that no lambdas in a group are equal to any lambdas outside that group,
                // or hash to the same value as lambda outside the group.
                // (Note that the hash collisions won't result in correctness problems but could
                // regress performs, and do not currently occurr for any of the test inputs.)
                for (int j = 0; j < lambdaGroups.size(); j++) {
                    if (i == j) {
                        continue;
                    }
                    for (JCLambda lhs : curr) {
                        for (JCLambda rhs : lambdaGroups.get(j)) {
                            if (new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
                                    .scan(lhs.body, rhs.body)) {
                                throw new AssertionError(
                                        String.format(
                                                "expected lambdas to not be equal\n%s\n%s",
                                                lhs, rhs));
                            }
                            if (TreeHasher.hash(lhs, paramSymbols(lhs))
                                    == TreeHasher.hash(rhs, paramSymbols(rhs))) {
                                throw new AssertionError(
                                        String.format(
                                                "expected lambdas to hash to different values\n%s\n%s",
                                                lhs, rhs));
                            }
                        }
                    }
                }
            }
            lambdaGroups.clear();
        }
    }
}
