File: websocket_handshake_stream_base.cc

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (166 lines) | stat: -rw-r--r-- 5,887 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/websockets/websocket_handshake_stream_base.h"

#include <stddef.h>

#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/websockets/websocket_extension.h"
#include "net/websockets/websocket_extension_parser.h"
#include "net/websockets/websocket_handshake_constants.h"

namespace net {

namespace {

// Joins the entries of |value| with ", " and stores the result in |headers|
// under |name|. When |value| is empty, |headers| is not touched.
// Returns the length of the joined header value, or 0 when |value| is empty.
size_t AddVectorHeaderIfNonEmpty(const char* name,
                                 const std::vector<std::string>& value,
                                 HttpRequestHeaders* headers) {
  size_t joined_length = 0u;
  if (!value.empty()) {
    std::string header_value = base::JoinString(value, ", ");
    joined_length = header_value.length();
    // The length is captured above because |header_value| is moved from here.
    headers->SetHeader(name, std::move(header_value));
  }
  return joined_length;
}

}  // namespace

// static
// Builds the failure text used when the response header |header_name| occurs
// more than once, e.g. for "Foo":
//   'Foo' header must not appear more than once in a response
std::string WebSocketHandshakeStreamBase::MultipleHeaderValuesMessage(
    const std::string& header_name) {
  return "'" + header_name +
         "' header must not appear more than once in a response";
}

// static
// Adds the Sec-WebSocket-Extensions header (from |extensions|) and the
// Sec-WebSocket-Protocol header (from |protocols|) to |headers|. Each header is
// added only when its list is non-empty. The joined length of the protocol
// header value is then recorded in the "Net.WebSocket.ProtocolHeaderSize"
// histogram; 0 is recorded when |protocols| is empty.
void WebSocketHandshakeStreamBase::AddVectorHeaders(
    const std::vector<std::string>& extensions,
    const std::vector<std::string>& protocols,
    HttpRequestHeaders* headers) {
  AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, extensions,
                            headers);
  // The protocol header is written while the histogram argument is evaluated,
  // i.e. before the sample is recorded.
  base::UmaHistogramCounts10000(
      "Net.WebSocket.ProtocolHeaderSize",
      AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, protocols,
                                headers));
}

// static
// Checks the Sec-WebSocket-Protocol header(s) of a handshake response against
// the sub-protocols the client requested in |requested_sub_protocols|.
//
// The response is accepted when either:
//  - no sub-protocols were requested and the header is absent, or
//  - sub-protocols were requested and the header occurs exactly once, with a
//    value equal to one of |requested_sub_protocols|.
//
// On success, returns true. |*sub_protocol| is set to the selected value, or
// cleared when no header was present.
// On failure, returns false and sets |*failure_message| to a description of the
// problem. |*sub_protocol| is not modified.
bool WebSocketHandshakeStreamBase::ValidateSubProtocol(
    const HttpResponseHeaders* headers,
    const std::vector<std::string>& requested_sub_protocols,
    std::string* sub_protocol,
    std::string* failure_message) {
  size_t iter = 0;
  // Holds the single accepted header value, once one has been seen.
  std::optional<std::string> value;
  while (std::optional<std::string_view> maybe_value = headers->EnumerateHeader(
             &iter, websockets::kSecWebSocketProtocol)) {
    // An earlier occurrence was already accepted, so this is a repeat.
    if (value) {
      *failure_message =
          MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
      return false;
    }
    if (requested_sub_protocols.empty()) {
      *failure_message =
          base::StrCat({"Response must not include 'Sec-WebSocket-Protocol' "
                        "header if not present in request: ",
                        *maybe_value});
      return false;
    }
    auto it = std::ranges::find(requested_sub_protocols, *maybe_value);
    if (it == requested_sub_protocols.end()) {
      *failure_message =
          base::StrCat({"'Sec-WebSocket-Protocol' header value '", *maybe_value,
                        "' in response does not match any of sent values"});
      return false;
    }
    value = *maybe_value;
  }

  if (!requested_sub_protocols.empty() && !value.has_value()) {
    *failure_message =
        "Sent non-empty 'Sec-WebSocket-Protocol' header "
        "but no response was received";
    return false;
  }
  if (value) {
    // |value| is not used again, so transfer its buffer instead of copying the
    // string a second time.
    *sub_protocol = std::move(*value);
  } else {
    sub_protocol->clear();
  }
  return true;
}

// static
bool WebSocketHandshakeStreamBase::ValidateExtensions(
    const HttpResponseHeaders* headers,
    std::string* accepted_extensions_descriptor,
    std::string* failure_message,
    WebSocketExtensionParams* params) {
  size_t iter = 0;
  std::vector<std::string> header_values;
  // TODO(ricea): If adding support for additional extensions, generalise this
  // code.
  bool seen_permessage_deflate = false;
  while (std::optional<std::string_view> header_value =
             headers->EnumerateHeader(&iter,
                                      websockets::kSecWebSocketExtensions)) {
    const std::vector<WebSocketExtension> extensions =
        ParseWebSocketExtensions(*header_value);
    if (extensions.empty()) {
      // TODO(yhirano) Set appropriate failure message.
      *failure_message =
          base::StrCat({"'Sec-WebSocket-Extensions' header value is "
                        "rejected by the parser: ",
                        *header_value});
      return false;
    }

    for (const auto& extension : extensions) {
      if (extension.name() == "permessage-deflate") {
        if (seen_permessage_deflate) {
          *failure_message = "Received duplicate permessage-deflate response";
          return false;
        }
        seen_permessage_deflate = true;
        auto& deflate_parameters = params->deflate_parameters;
        if (!deflate_parameters.Initialize(extension, failure_message) ||
            !deflate_parameters.IsValidAsResponse(failure_message)) {
          *failure_message = "Error in permessage-deflate: " + *failure_message;
          return false;
        }
        // Note that we don't have to check the request-response compatibility
        // here because we send a request compatible with any valid responses.
        // TODO(yhirano): Place a DCHECK here.

        header_values.emplace_back(*header_value);
      } else {
        *failure_message = "Found an unsupported extension '" +
                           extension.name() +
                           "' in 'Sec-WebSocket-Extensions' header";
        return false;
      }
    }
  }
  *accepted_extensions_descriptor = base::JoinString(header_values, ", ");
  params->deflate_enabled = seen_permessage_deflate;
  return true;
}

// Records |result| as one sample in the "Net.WebSocket.HandshakeResult2"
// enumeration histogram. HandshakeResult::NUM_HANDSHAKE_RESULT_TYPES is passed
// as the macro's boundary argument; presumably it is the enumerator count.
// TODO confirm against the HandshakeResult declaration.
void WebSocketHandshakeStreamBase::RecordHandshakeResult(
    HandshakeResult result) {
  UMA_HISTOGRAM_ENUMERATION("Net.WebSocket.HandshakeResult2", result,
                            HandshakeResult::NUM_HANDSHAKE_RESULT_TYPES);
}

}  // namespace net