/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tomcat.websocket;

import java.net.URI;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.websocket.ClientEndpointConfig;
import jakarta.websocket.ContainerProvider;
import jakarta.websocket.Session;
import jakarta.websocket.WebSocketContainer;

import org.junit.Assert;
import org.junit.Test;

import org.apache.catalina.Context;
import org.apache.catalina.authenticator.AuthenticatorBase;
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.util.descriptor.web.LoginConfig;
import org.apache.tomcat.util.descriptor.web.SecurityCollection;
import org.apache.tomcat.util.descriptor.web.SecurityConstraint;
import org.apache.tomcat.websocket.TesterMessageCountClient.BasicText;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;

/*
 * Tests WebSocket connections via a forward proxy.
 *
 * These tests have been successfully used with Apache Web Server (httpd)
 * configured with the following:
 *
 * Listen 8888
 * <VirtualHost *:8888>
 *     ProxyRequests On
 *     ProxyVia On
 *     AllowCONNECT 0-65535
 * </VirtualHost>
 *
 * Listen 8889
 * <VirtualHost *:8889>
 *     ProxyRequests On
 *     ProxyVia On
 *     AllowCONNECT 0-65535
 *     <Proxy *>
 *         Order deny,allow
 *         Allow from all
 *         AuthType Basic
 *         AuthName "Proxy Password Required"
 *         AuthUserFile password.file
 *         Require valid-user
 *     </Proxy>
 * </VirtualHost>
 *
 * and
 * # htpasswd -c password.file proxy
 * New Password: proxy-pass
 *
 */
public class TesterWebSocketClientProxy extends WebSocketBaseTest {

    private static final String MESSAGE_STRING = "proxy-test-message";

    private static final String PROXY_ADDRESS = "192.168.0.200";
    private static final String PROXY_PORT_NO_AUTH = "8888";
    private static final String PROXY_PORT_AUTH = "8889";
    // The IP address of the test instance that is reachable from the proxy
    private static final String TOMCAT_ADDRESS = "192.168.0.100";

    private static final String TOMCAT_USER = "tomcat";
    private static final String TOMCAT_PASSWORD = "tomcat-pass";
    private static final String TOMCAT_ROLE = "tomcat-role";

    private static final String PROXY_USER = "proxy";
    private static final String PROXY_PASSWORD = "proxy-pass";

    @Test
    public void testConnectToServerViaProxyWithNoAuthentication() throws Exception {
        doTestConnectToServerViaProxy(false, false);
    }


    @Test
    public void testConnectToServerViaProxyWithServerAuthentication() throws Exception {
        doTestConnectToServerViaProxy(true, false);
    }


    @Test
    public void testConnectToServerViaProxyWithProxyAuthentication() throws Exception {
        doTestConnectToServerViaProxy(false, true);
    }


    @Test
    public void testConnectToServerViaProxyWithServerAndProxyAuthentication() throws Exception {
        doTestConnectToServerViaProxy(true, true);
    }


    private void doTestConnectToServerViaProxy(boolean serverAuthentication, boolean proxyAuthentication)
            throws Exception {

        // Configure the proxy
        System.setProperty("http.proxyHost", PROXY_ADDRESS);
        if (proxyAuthentication) {
            System.setProperty("http.proxyPort", PROXY_PORT_AUTH);
        } else {
            System.setProperty("http.proxyPort", PROXY_PORT_NO_AUTH);
        }

        Tomcat tomcat = getTomcatInstance();

        // Need to listen on all addresses, not just loop-back
        tomcat.getConnector().setProperty("address", "0.0.0.0");

        // No file system docBase required
        Context ctx = getProgrammaticRootContext();
        ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
        Tomcat.addServlet(ctx, "default", new DefaultServlet());
        ctx.addServletMappingDecoded("/", "default");

        if (serverAuthentication) {
            // Configure Realm
            tomcat.addUser(TOMCAT_USER, TOMCAT_PASSWORD);
            tomcat.addRole(TOMCAT_USER, TOMCAT_ROLE);

            // Configure security constraints
            SecurityCollection securityCollection = new SecurityCollection();
            securityCollection.addPatternDecoded("/*");
            SecurityConstraint securityConstraint = new SecurityConstraint();
            securityConstraint.addAuthRole(TOMCAT_ROLE);
            securityConstraint.addCollection(securityCollection);
            ctx.addConstraint(securityConstraint);

            // Configure authenticator
            LoginConfig loginConfig = new LoginConfig();
            loginConfig.setAuthMethod(BasicAuthenticator.schemeName);
            ctx.setLoginConfig(loginConfig);
            AuthenticatorBase basicAuthenticator = new org.apache.catalina.authenticator.BasicAuthenticator();
            ctx.getPipeline().addValve(basicAuthenticator);
        }

        tomcat.start();

        WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();

        ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
        // Configure the client
        if (serverAuthentication) {
            clientEndpointConfig.getUserProperties().put(Constants.WS_AUTHENTICATION_USER_NAME, TOMCAT_USER);
            clientEndpointConfig.getUserProperties().put(Constants.WS_AUTHENTICATION_PASSWORD, TOMCAT_PASSWORD);
        }
        if (proxyAuthentication) {
            clientEndpointConfig.getUserProperties().put(Constants.WS_AUTHENTICATION_PROXY_USER_NAME, PROXY_USER);
            clientEndpointConfig.getUserProperties().put(Constants.WS_AUTHENTICATION_PROXY_PASSWORD, PROXY_PASSWORD);
        }

        Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
                new URI("ws://" + TOMCAT_ADDRESS + ":" + getPort() + TesterEchoServer.Config.PATH_ASYNC));
        CountDownLatch latch = new CountDownLatch(1);
        BasicText handler = new BasicText(latch);
        wsSession.addMessageHandler(handler);
        wsSession.getBasicRemote().sendText(MESSAGE_STRING);

        boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);

        Assert.assertTrue(latchResult);

        Queue<String> messages = handler.getMessages();
        Assert.assertEquals(1, messages.size());
        Assert.assertEquals(MESSAGE_STRING, messages.peek());
    }
}
