/*
 *  Licensed to the Apache Software Foundation (ASF) under one or more
 *  contributor license agreements.  See the NOTICE file distributed with
 *  this work for additional information regarding copyright ownership.
 *  The ASF licenses this file to You under the Apache License, Version 2.0
 *  (the "License"); you may not use this file except in compliance with
 *  the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.apache.tomcat.websocket.server;

import java.net.URI;
import java.util.concurrent.atomic.AtomicInteger;

import jakarta.servlet.ServletContextEvent;
import jakarta.websocket.ClientEndpoint;
import jakarta.websocket.CloseReason;
import jakarta.websocket.ContainerProvider;
import jakarta.websocket.DeploymentException;
import jakarta.websocket.OnClose;
import jakarta.websocket.OnError;
import jakarta.websocket.OnMessage;
import jakarta.websocket.OnOpen;
import jakarta.websocket.Session;
import jakarta.websocket.WebSocketContainer;
import jakarta.websocket.server.ServerContainer;
import jakarta.websocket.server.ServerEndpointConfig;

import org.junit.Assert;
import org.junit.Test;

import org.apache.catalina.Context;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.WebSocketBaseTest;

public class TestCloseBug58624 extends WebSocketBaseTest {

    @Test
    public void testOnErrorNotCalledWhenClosingConnection() throws Throwable {
        Tomcat tomcat = getTomcatInstance();
        // No file system docBase required
        Context ctx = getProgrammaticRootContext();
        ctx.addApplicationListener(Bug58624ServerConfig.class.getName());
        Tomcat.addServlet(ctx, "default", new DefaultServlet());
        ctx.addServletMappingDecoded("/", "default");

        WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();

        tomcat.start();

        Bug58624ClientEndpoint client = new Bug58624ClientEndpoint();
        URI uri = new URI("ws://localhost:" + getPort() + Bug58624ServerConfig.PATH);

        Session session = wsContainer.connectToServer(client, uri);

        // Wait for session to open on the server
        int count = 0;
        while (count < 50 && Bug58624ServerEndpoint.getOpenSessionCount() == 0) {
            count++;
            Thread.sleep(100);
        }
        Assert.assertNotEquals(0, Bug58624ServerEndpoint.getOpenSessionCount());

        // Now close the session
        session.close();

        // Wait for session to close on the server
        count = 0;
        while (count < 50 && Bug58624ServerEndpoint.getOpenSessionCount() > 0) {
            count++;
            Thread.sleep(100);
        }
        Assert.assertEquals(0, Bug58624ServerEndpoint.getOpenSessionCount());

        // Ensure no errors were reported on the server
        Assert.assertEquals(0, Bug58624ServerEndpoint.getErrorCount());

        if (client.getError() != null) {
            throw client.getError();
        }
    }

    @ClientEndpoint
    public static class Bug58624ClientEndpoint {

        private volatile Throwable t;


        @OnError
        public void onError(Throwable t) {
            this.t = t;
        }


        public Throwable getError() {
            return this.t;
        }
    }

    public static class Bug58624ServerConfig extends WsContextListener {

        public static final String PATH = "/bug58624";


        @Override
        public void contextInitialized(ServletContextEvent sce) {
            super.contextInitialized(sce);

            ServerContainer sc = (ServerContainer) sce.getServletContext()
                    .getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);

            ServerEndpointConfig sec = ServerEndpointConfig.Builder.create(Bug58624ServerEndpoint.class, PATH).build();

            try {
                sc.addEndpoint(sec);
            } catch (DeploymentException e) {
                throw new RuntimeException(e);
            }
        }
    }

    public static class Bug58624ServerEndpoint {

        private static AtomicInteger openSessionCount = new AtomicInteger(0);
        private static AtomicInteger errorCount = new AtomicInteger(0);

        public static int getOpenSessionCount() {
            return openSessionCount.get();
        }

        public static int getErrorCount() {
            return errorCount.get();
        }

        @OnOpen
        public void onOpen() {
            openSessionCount.incrementAndGet();
        }


        @OnMessage
        public void onMessage(@SuppressWarnings("unused") Session session, String message) {
            System.out.println("Received message " + message);
        }


        @OnError
        public void onError(Throwable t) {
            errorCount.incrementAndGet();
            t.printStackTrace();
        }


        @OnClose
        public void onClose(@SuppressWarnings("unused") CloseReason cr) {
            openSessionCount.decrementAndGet();
        }
    }
}
