File: networkaccessmanager.cpp

package info (click to toggle)
libquotient 0.9.5-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,588 kB
  • sloc: xml: 39,103; cpp: 25,226; sh: 97; makefile: 10
file content (200 lines) | stat: -rw-r--r-- 6,816 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
// SPDX-FileCopyrightText: 2018 Kitsune Ral <kitsune-ral@users.sf.net>
// SPDX-License-Identifier: LGPL-2.1-or-later

#include "networkaccessmanager.h"

#include "connectiondata.h"
#include "logging_categories_p.h"
#include "mxcreply.h"

#include "events/filesourceinfo.h"
#include "jobs/downloadfilejob.h" // For DownloadFileJob::makeRequestUrl() only

#include <QtCore/QCoreApplication>
#include <QtCore/QReadWriteLock>
#include <QtCore/QSettings>
#include <QtCore/QStringBuilder>
#include <QtCore/QThread>
#include <QtNetwork/QNetworkReply>

using namespace Quotient;

namespace {
// Registry of per-account homeserver data and of SSL errors to ignore.
// There is a single object of this (unnamed) class, `d`, at namespace scope;
// every member function takes namLock (write lock for mutators, read lock for
// getters) before touching the stored data.
class {
public:
    // Homeserver data stored under the id of the account it belongs to
    struct ConnectionData {
        QString accountId;
        HomeserverData hsData;
    };

    // Stores hsData for accountId, overwriting the data of an already known
    // account. Data with an empty base URL is not stored.
    void addConnection(const QString& accountId, HomeserverData hsData)
    {
        if (hsData.baseUrl.isEmpty())
            return;

        const QWriteLocker writeGuard(&namLock);
        const auto entry =
            std::ranges::find(connectionData, accountId, &ConnectionData::accountId);
        if (entry == connectionData.end()) {
            // Xcode doesn't like emplace_back() here for some reason (anon class?..)
            connectionData.push_back({ accountId, std::move(hsData) });
            return;
        }
        entry->hsData = std::move(hsData);
    }

    // Replaces the supported spec versions recorded for accountId. An empty
    // list is ignored; an unknown account is reported through QUO_ALARM_X and
    // nothing is stored.
    void addSpecVersions(QStringView accountId, const QStringList& versions)
    {
        if (versions.isEmpty())
            return;

        const QWriteLocker writeGuard(&namLock);
        const auto entry =
            std::ranges::find(connectionData, accountId, &ConnectionData::accountId);
        if (QUO_ALARM_X(entry == connectionData.end(),
                        "Quotient::NAM: Trying to save supported spec "
                        "versions on an inexistent account"))
            return;

        entry->hsData.supportedSpecVersions = versions;
    }

    // Removes every entry stored under accountId (no-op if there is none)
    void dropConnection(QStringView accountId)
    {
        const QWriteLocker writeGuard(&namLock);
        std::erase_if(connectionData, [accountId](const ConnectionData& entry) {
            return entry.accountId == accountId;
        });
    }

    // Returns a copy of the data stored for accountId, or a default-constructed
    // HomeserverData if the account is not known
    HomeserverData getConnection(const QString& accountId) const
    {
        const QReadLocker readGuard(&namLock);
        const auto entry =
            std::ranges::find(connectionData, accountId, &ConnectionData::accountId);
        if (entry == connectionData.cend())
            return HomeserverData{};
        return entry->hsData;
    }

    // Appends error to the list of SSL errors to ignore
    void addIgnoredSslError(const QSslError& error)
    {
        const QWriteLocker writeGuard(&namLock);
        ignoredSslErrors.push_back(error);
    }

    // Empties the list of SSL errors to ignore
    void clearIgnoredSslErrors()
    {
        const QWriteLocker writeGuard(&namLock);
        ignoredSslErrors.clear();
    }

    // Returns a copy of the list of SSL errors to ignore
    QList<QSslError> getIgnoredSslErrors() const
    {
        const QReadLocker readGuard(&namLock);
        return ignoredSslErrors;
    }

    // Replaces the access token stored for userId; does nothing if no data
    // is stored for that account
    void setAccessToken(const QString& userId, const QByteArray& accessToken)
    {
        const QWriteLocker writeGuard(&namLock);
        const auto entry =
            std::ranges::find(connectionData, userId, &ConnectionData::accountId);
        if (entry == connectionData.end())
            return;
        entry->hsData.accessToken = accessToken;
    }

private:
    // mutable so that the const getters can take the read lock
    mutable QReadWriteLock namLock{};
    std::vector<ConnectionData> connectionData{};
    QList<QSslError> ignoredSslErrors{};
} d;

} // anonymous namespace

// Records the homeserver URL and access token for accountId in the file-level
// registry `d`, overwriting the data of an account that is already recorded.
// accountId must not be empty (checked with Q_ASSERT); the registry does not
// store anything when the homeserver URL is empty.
void NetworkAccessManager::addAccount(const QString& accountId, const QUrl& homeserver,
                                      const QByteArray& accessToken)
{
    Q_ASSERT(!accountId.isEmpty());
    d.addConnection(accountId, { homeserver, accessToken });
}

// Replaces the access token recorded for userId in the registry `d`.
// Nothing happens if no data is recorded for that account.
void NetworkAccessManager::setAccessToken(const QString& userId, const QByteArray& token)
{
    d.setAccessToken(userId, token);
}

// Replaces the list of supported spec versions recorded for accountId.
// accountId must not be empty (checked with Q_ASSERT). The registry ignores an
// empty versions list, and flags an unknown account via QUO_ALARM_X without
// storing anything.
void NetworkAccessManager::updateAccountSpecVersions(QStringView accountId,
                                                     const QStringList& versions)
{
    Q_ASSERT(!accountId.isEmpty());
    d.addSpecVersions(accountId, versions);
}

// Removes the data recorded for accountId from the registry `d`; does nothing
// if the account is not recorded.
void NetworkAccessManager::dropAccount(QStringView accountId)
{
    d.dropConnection(accountId);
}

// Returns a copy of the list of SSL errors that createRequest() tells every
// reply it creates to ignore.
QList<QSslError> NetworkAccessManager::ignoredSslErrors()
{
    return d.getIgnoredSslErrors();
}

// Appends error to the list of SSL errors to ignore; the list is kept in the
// file-level registry `d` rather than in this object.
void NetworkAccessManager::addIgnoredSslError(const QSslError& error)
{
    d.addIgnoredSslError(error);
}

// Empties the list of SSL errors to ignore kept in the file-level registry `d`.
void NetworkAccessManager::clearIgnoredSslErrors()
{
    d.clearIgnoredSslErrors();
}

// Returns the NetworkAccessManager belonging to the calling thread.
// The manager is created the first time a given thread calls this function
// and is scheduled for deletion (deleteLater) when that thread's QThread
// emits finished().
NetworkAccessManager* NetworkAccessManager::instance()
{
    thread_local NetworkAccessManager* const threadNam = [] {
        auto* const freshNam = new NetworkAccessManager();
        QObject::connect(QThread::currentThread(), &QThread::finished, freshNam,
                         &QObject::deleteLater);
        return freshNam;
    }();
    return threadNam;
}

// Creates the reply for a request, adding handling of mxc: URLs on top of
// QNetworkAccessManager::createRequest().
//
// Requests with any other scheme go to the base implementation unchanged.
// An mxc request must carry a `user_id` query item naming an account recorded
// in the registry; it is then re-issued against that account's homeserver
// with a bearer Authorization header, and the resulting reply is wrapped in
// an MxcReply together with the file metadata looked up by the `room_id` and
// `event_id` query items. outgoingData is not forwarded for mxc requests.
// If the account is missing or unknown, a default-constructed MxcReply is
// returned after logging the reason.
// Every reply created through the base implementation is told to ignore
// the SSL errors currently recorded in the registry.
QNetworkReply* NetworkAccessManager::createRequest(
    Operation op, const QNetworkRequest& request, QIODevice* outgoingData)
{
    const QUrl requestUrl = request.url();
    if (requestUrl.scheme() != "mxc"_L1) {
        // Not a Matrix content URL: plain pass-through
        auto* plainReply = QNetworkAccessManager::createRequest(op, request, outgoingData);
        plainReply->ignoreSslErrors(d.getIgnoredSslErrors());
        return plainReply;
    }

    const QUrlQuery mxcQuery{ requestUrl.query() };
    const QString accountId = mxcQuery.queryItemValue(u"user_id"_s);
    if (accountId.isEmpty()) {
        // Using QSettings here because Quotient::NetworkSettings
        // doesn't provide multi-threading guarantees
        static thread_local const QSettings settings;
        if (settings.value("Network/allow_direct_media_requests"_L1).toBool()) {
            // TODO: Make the best effort with a direct unauthenticated request
            // to the media server
            qCWarning(NETWORK) << "Direct unauthenticated mxc requests are not implemented";
        } else {
            qCWarning(NETWORK) << "No connection specified, cannot convert mxc request";
        }
        // Either way there is nothing to send the request to
        return new MxcReply();
    }

    const HomeserverData hsData = d.getConnection(accountId);
    if (!hsData.baseUrl.isValid()) {
        // Strictly speaking, it should be an assert...
        qCCritical(NETWORK) << "Homeserver for" << accountId
                            << "not found, cannot convert mxc request";
        return new MxcReply();
    }

    // Convert mxc:// URL into normal http(s) for the given homeserver
    QNetworkRequest httpRequest(request);
    httpRequest.setUrl(DownloadFileJob::makeRequestUrl(hsData, requestUrl));
    httpRequest.setRawHeader("Authorization", "Bearer "_ba + hsData.accessToken);

    auto* const innerReply = QNetworkAccessManager::createRequest(op, httpRequest);
    innerReply->ignoreSslErrors(d.getIgnoredSslErrors());

    const auto& fileMetadata = FileMetadataMap::lookup(mxcQuery.queryItemValue(u"room_id"_s),
                                                       mxcQuery.queryItemValue(u"event_id"_s));
    return new MxcReply(innerReply, fileMetadata);
}

// Returns the schemes reported by the base QNetworkAccessManager implementation
// with "mxc" appended at the end.
QStringList NetworkAccessManager::supportedSchemesImplementation() const
{
    auto schemes = QNetworkAccessManager::supportedSchemesImplementation();
    schemes.append(u"mxc"_s);
    return schemes;
}