/*
 * Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

/*
 * @test
 * @enablePreview
 * @run testng TestReshape
 */

import java.lang.foreign.MemoryLayout;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.LongStream;

import org.testng.annotations.*;
import static org.testng.Assert.*;

public class TestReshape {

    /**
     * Reshapes a one-dimensional sequence of {@code layout} elements into {@code expectedShape}.
     * Every shape produced by {@link Shape} is tried: the fully explicit shape, and each variant
     * with one dimension left to inference ({@code -1}). For each result, the test checks the
     * dimensions, the innermost element layout, and that flattening gives back the original
     * sequence.
     */
    @Test(dataProvider = "shapes")
    public void testReshape(MemoryLayout layout, long[] expectedShape) {
        long flattenedSize = LongStream.of(expectedShape).reduce(1L, Math::multiplyExact);
        SequenceLayout flattened = MemoryLayout.sequenceLayout(flattenedSize, layout);
        assertDimensions(flattened, flattenedSize);
        for (long[] shape : new Shape(expectedShape)) {
            SequenceLayout reshaped = flattened.reshape(shape);
            assertDimensions(reshaped, expectedShape);
            // assertDimensions only inspects the outer expectedShape.length levels; also require
            // that exactly those levels wrap the original element layout (no extra nesting).
            assertEquals(elementAtDepth(reshaped, expectedShape.length), layout);
            assertEquals(reshaped.flatten(), flattened);
        }
    }

    /** 3 * 2 = 6 does not match the 4 elements of the sequence. */
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidReshape() {
        SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
        seq.reshape(3, 2);
    }

    /** At most one dimension may be inferred with -1. */
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testBadReshapeInference() {
        SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
        seq.reshape(-1, -1);
    }

    /** A zero element count is rejected. */
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testBadReshapeParameterZero() {
        SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
        seq.reshape(0, 4);
    }

    /** Negative element counts other than -1 are rejected. */
    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testBadReshapeParameterNegative() {
        SequenceLayout seq = MemoryLayout.sequenceLayout(4, ValueLayout.JAVA_INT);
        seq.reshape(-2, 2);
    }

    /**
     * Asserts that the first {@code dims.length} nesting levels of {@code layout} have the given
     * element counts, outermost first. Each level before the last must itself contain a
     * {@link SequenceLayout}; otherwise the cast fails the test.
     */
    static void assertDimensions(SequenceLayout layout, long... dims) {
        SequenceLayout current = layout;
        for (int i = 0; i < dims.length; i++) {
            if (i > 0) {
                current = (SequenceLayout) current.elementLayout();
            }
            assertEquals(current.elementCount(), dims[i]);
        }
    }

    /** Returns the layout reached by descending {@code depth} sequence levels into {@code layout}. */
    private static MemoryLayout elementAtDepth(SequenceLayout layout, int depth) {
        MemoryLayout current = layout;
        for (int i = 0; i < depth; i++) {
            current = ((SequenceLayout) current).elementLayout();
        }
        return current;
    }

    /**
     * Iterates over the ways a shape can be passed to {@code reshape}. It yields the explicit shape
     * first, then one copy per dimension with that dimension replaced by {@code -1} (inferred).
     */
    static class Shape implements Iterable<long[]> {
        private final long[] shape;

        Shape(long... shape) {
            this.shape = shape;
        }

        @Override
        public Iterator<long[]> iterator() {
            List<long[]> shapes = new ArrayList<>(shape.length + 1);
            shapes.add(shape);
            for (int i = 0; i < shape.length; i++) {
                long[] inferredShape = shape.clone();
                inferredShape[i] = -1;
                shapes.add(inferredShape);
            }
            return shapes.iterator();
        }
    }

    static final MemoryLayout POINT = MemoryLayout.structLayout(
            ValueLayout.JAVA_INT,
            ValueLayout.JAVA_INT
    );

    /**
     * Supplies (element layout, expected shape) pairs: the cross product of the element layouts
     * and shapes below, grouped by element layout. Every shape multiplies out to 256 elements.
     */
    @DataProvider(name = "shapes")
    Object[][] shapes() {
        MemoryLayout[] elementLayouts = {
                ValueLayout.JAVA_BYTE,
                ValueLayout.JAVA_SHORT,
                ValueLayout.JAVA_CHAR,
                ValueLayout.JAVA_INT,
                ValueLayout.JAVA_LONG,
                ValueLayout.JAVA_FLOAT,
                ValueLayout.JAVA_DOUBLE,
                POINT,
        };
        long[][] shapeTable = {
                { 256 },
                { 16, 16 },
                { 4, 4, 4, 4 },
                { 2, 8, 16 },
                { 16, 8, 2 },
                { 8, 16, 2 },
        };
        List<Object[]> cases = new ArrayList<>(elementLayouts.length * shapeTable.length);
        for (MemoryLayout elementLayout : elementLayouts) {
            for (long[] shape : shapeTable) {
                // Give each test case its own array so no two cases share mutable state.
                cases.add(new Object[] { elementLayout, shape.clone() });
            }
        }
        return cases.toArray(new Object[0][]);
    }
}
