1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
|
// -*- mode: cpp; mode: fold -*-
// Description /*{{{*/
/* ######################################################################
SRV record support
##################################################################### */
/*}}}*/
#include <config.h>

#include <netdb.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <netinet/in.h>
#include <resolv.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <ctime>
#include <memory>
#include <random>
#include <tuple>

#include <apt-pkg/configuration.h>
#include <apt-pkg/error.h>
#include <apt-pkg/strutl.h>

#include "srvrec.h"
// Two SRV records are considered equal when their target, priority, weight
// and port all match (compared in that order).
bool SrvRec::operator==(SrvRec const &other) const
{
   return target == other.target &&
          priority == other.priority &&
          weight == other.weight &&
          port == other.port;
}
// Derive the SRV service name "_<service>._tcp.<host>" from the TCP service
// registered for @port and append the SRV records found for it to @Result,
// delegating the actual DNS work to the name-based overload.
//
// Numeric IPv4/IPv6 addresses cannot carry SRV records, so for those nothing
// is looked up and true is returned. If getservbyport_r() fails or knows no
// TCP service for @port, false is returned. Otherwise the result of the
// name-based GetSrvRecords() is returned.
bool GetSrvRecords(std::string host, int port, std::vector<SrvRec> &Result)
{
   // only host names get an SRV lookup; address literals are skipped
   {
      struct in_addr v4addr;
      if (inet_pton(AF_INET, host.c_str(), &v4addr) == 1)
         return true;
      struct in6_addr v6addr;
      if (inet_pton(AF_INET6, host.c_str(), &v6addr) == 1)
         return true;
   }

   // translate the port number into its registered TCP service name
   struct servent entry;
   struct servent *service = nullptr;
   std::array<char, 1024> scratch;
   int const rc = getservbyport_r(htons(port), "tcp", &entry,
                                  scratch.data(), scratch.size(), &service);
   if (rc != 0 || service == nullptr)
      return false;

   std::string srvName;
   strprintf(srvName, "_%s._tcp.%s", service->s_name, host.c_str());
   return GetSrvRecords(srvName, Result);
}
// Query DNS (class IN, type SRV) for @name, e.g. "_http._tcp.example.org",
// append one SrvRec per answer record to @Result and std::stable_sort()
// @Result with SrvRec's operator<.
//
// Returns:
//  - the result of _error->Errno() if the resolver cannot be initialised,
//  - false (without reporting an error) if the query itself fails,
//  - the result of _error->Warning() for malformed/unexpected responses
//    (records parsed before the problem stay in @Result, unsorted),
//  - true otherwise.
bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
{
   unsigned char answer[PACKETSZ];
   int answer_len, compressed_name_len;
   int answer_count;

#if __RES >= 19991006
   struct __res_state res;
   if (res_ninit(&res) != 0)
      return _error->Errno("res_init", "Failed to init resolver");

   // Close on return
   std::shared_ptr<void> guard(&res, res_nclose);

   answer_len = res_nquery(&res, name.c_str(), C_IN, T_SRV, answer, sizeof(answer));
#else
   if (res_init() != 0)
      return _error->Errno("res_init", "Failed to init resolver");
   answer_len = res_query(name.c_str(), C_IN, T_SRV, answer, sizeof(answer));
#endif //__RES >= 19991006
   if (answer_len == -1)
      return false;

   // BIND-derived resolvers may return the length of the *complete* response
   // even when it was truncated to fit into our buffer; never parse beyond
   // the bytes we actually own.
   if (answer_len > static_cast<int>(sizeof(answer)))
      answer_len = static_cast<int>(sizeof(answer));
   if (answer_len < static_cast<int>(sizeof(HEADER)))
      return _error->Warning("Not enough data from res_query (%i)", answer_len);
   unsigned char const * const answer_end = answer + answer_len;

   // check the header
   HEADER const * const header = reinterpret_cast<HEADER const *>(answer);
   if (header->rcode != NOERROR)
      return _error->Warning("res_query returned rcode %i", header->rcode);
   answer_count = ntohs(header->ancount);
   if (answer_count <= 0)
      return _error->Warning("res_query returned no answers (%i) ", answer_count);

   // skip the header and the question entry (assumes exactly one question;
   // qdcount is not checked): its name followed by QFIXEDSZ bytes of
   // type and class
   compressed_name_len = dn_skipname(answer + sizeof(HEADER), answer_end);
   if (compressed_name_len < 0)
      return _error->Warning("dn_skipname failed %i", compressed_name_len);
   unsigned char const *pt = answer + sizeof(HEADER) + compressed_name_len;
   if ((answer_end - pt) < QFIXEDSZ)
      return _error->Warning("packet too short");
   pt += QFIXEDSZ;

   // read a 16-bit value stored in network (big-endian) byte order and
   // advance the cursor past it
   auto const extract_u16 = [](unsigned char const *&p) -> u_int16_t {
      u_int16_t const value = static_cast<u_int16_t>((p[0] << 8) | p[1]);
      p += 2;
      return value;
   };

   // pt points to the first answer record, go over all of them now
   for (int parsed = 0; parsed < answer_count && pt < answer_end; ++parsed)
   {
      compressed_name_len = dn_skipname(pt, answer_end);
      if (compressed_name_len < 0)
         return _error->Warning("dn_skipname failed (2): %i",
               compressed_name_len);
      pt += compressed_name_len;

      // fixed part: type(2) class(2) ttl(4) rdlength(2), followed by the
      // fixed SRV rdata fields priority(2) weight(2) port(2) = 16 bytes
      if ((answer_end - pt) < 16)
         return _error->Warning("packet too short");

      u_int16_t const type = extract_u16(pt);
      if (type != T_SRV)
         return _error->Warning("Unexpected type expected %x != %x",
               T_SRV, type);
      u_int16_t const klass = extract_u16(pt);
      if (klass != C_IN)
         return _error->Warning("Unexpected class expected %x != %x",
               C_IN, klass);
      pt += 4; // ttl
      u_int16_t const dlen = extract_u16(pt);

      // the rdata must at least hold priority, weight and port and must
      // not extend beyond the received data
      unsigned char const * const rdata = pt;
      if (dlen < 6 || (answer_end - rdata) < dlen)
         return _error->Warning("packet too short");

      u_int16_t const priority = extract_u16(pt);
      u_int16_t const weight = extract_u16(pt);
      u_int16_t const port = extract_u16(pt);

      char target_name[MAXDNAME];
      compressed_name_len = dn_expand(answer, answer_end, pt, target_name, sizeof(target_name));
      if (compressed_name_len < 0)
         return _error->Warning("dn_expand failed %i", compressed_name_len);

      // the next record starts right after this one's rdata, as announced
      // by its length field
      pt = rdata + dlen;

      // add it to our class
      Result.emplace_back(target_name, priority, weight, port);
   }

   // implement load balancing as specified in RFC-2782
   // sort them by priority
   std::stable_sort(Result.begin(), Result.end());

   if (_config->FindB("Debug::Acquire::SrvRecs", false))
      for (auto const &R : Result)
         std::cerr << "SrvRecs: got " << R.target
                   << " prio: " << R.priority
                   << " weight: " << R.weight
                   << '\n';

   return true;
}
// Remove one record from @Recs and return it.
//
// The record is picked uniformly at random among the leading run of records
// that share the priority of Recs.front(). @Recs is expected to be ordered
// by priority, as GetSrvRecords() leaves it after its stable_sort.
//
// Precondition: @Recs must not be empty (callers check this before popping).
//
// FIXME: instead of the uniform pick use the weighted selection described
//        in RFC 2782 and figure out how the weights need to be adjusted if
//        a host refuses connections.
SrvRec PopFromSrvRecs(std::vector<SrvRec> &Recs)
{
   auto const first = Recs.begin();
   auto const group_end = std::find_if(first, Recs.end(),
         [&first](SrvRec const &R) { return R.priority != first->priority; });

   // clock() reports consumed CPU time, which is small and similar in every
   // freshly started process, so it spread the choice poorly. Seed a small
   // engine once per thread from the monotonic clock instead.
   thread_local std::minstd_rand engine(static_cast<std::minstd_rand::result_type>(
         std::chrono::steady_clock::now().time_since_epoch().count()));
   std::uniform_int_distribution<std::ptrdiff_t> pick(0, std::distance(first, group_end) - 1);
   auto const chosen = first + pick(engine);

   // not const, so the return below can move instead of copy
   SrvRec selected = std::move(*chosen);
   Recs.erase(chosen);

   if (_config->FindB("Debug::Acquire::SrvRecs", false) == true)
      std::cerr << "PopFromSrvRecs: selecting " << selected.target << std::endl;
   return selected;
}
|