File: lua-base4.cc

package info (click to toggle)
pdns-recursor 5.3.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 11,108 kB
  • sloc: cpp: 109,513; javascript: 20,651; python: 5,657; sh: 5,069; makefile: 780; ansic: 582; xml: 37
file content (343 lines) | stat: -rw-r--r-- 18,228 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
#include "config.h"
#include <cassert>
#include <fstream>
#include <unordered_set>
#include <unordered_map>
#include <typeinfo>
#include <sys/stat.h>
#include "logger.hh"
#include "logging.hh"
#include "iputils.hh"
#include "dnsname.hh"
#include "dnsparser.hh"
#include "dnspacket.hh"
#include "namespaces.hh"
#include "ednssubnet.hh"
#include "lua-base4.hh"
#include "ext/luawrapper/include/LuaContext.hpp"
#include "dns_random.hh"

void BaseLua4::loadFile(const std::string& fname, bool doPostLoad)
{
  std::ifstream ifs(fname);
  if (!ifs) {
    auto ret = errno;
    auto msg = stringerror(ret);
    g_log << Logger::Error << "Unable to read configuration file from '" << fname << "': " << msg << endl;
    throw std::runtime_error(msg);
  }
  loadStream(ifs, doPostLoad);
};

// Executes Lua code held in memory: `script` is wrapped in a string stream and
// passed to loadStream() with doPostLoad set to true.
void BaseLua4::loadString(const std::string& script)
{
  std::istringstream input{script};
  loadStream(input, true);
}

void BaseLua4::includePath(const std::string& directory) {
  std::vector<std::string> vec;
  const std::string& suffix = "lua";
  auto directoryError = pdns::visit_directory(directory, [this, &directory, &suffix, &vec]([[maybe_unused]] ino_t inodeNumber, const std::string_view& name) {
    (void)this;
    if (boost::starts_with(name, ".")) {
      return true; // skip any dots
    }
    if (boost::ends_with(name, suffix)) {
      // build name
      string fullName = directory + "/" + std::string(name);
      // ensure it's readable file
      struct stat statInfo
      {
      };
      if (stat(fullName.c_str(), &statInfo) != 0 || !S_ISREG(statInfo.st_mode)) {
        string msg = fullName + " is not a regular file";
        g_log << Logger::Error << msg << std::endl;
        throw PDNSException(std::move(msg));
      }
      vec.emplace_back(fullName);
    }
    return true;
  });

  if (directoryError) {
    int err = errno;
    string msg = directory + " is not accessible: " + stringerror(err);
    g_log << Logger::Error << msg << std::endl;
    throw PDNSException(std::move(msg));
  }

  std::sort(vec.begin(), vec.end(), CIStringComparePOSIX());

  for(const auto& file: vec) {
    loadFile(file, false);
  }
};

// Base implementation of the feature hook: the Features object it is given is
// left exactly as received, so by default no features are announced.
void BaseLua4::getFeatures(Features& /* unused */)
{
}

// Builds a fresh Lua context in d_lw and populates it with the bindings that
// scripts can use:
//  - the global `pdns_features`, filled in through getFeatures()
//  - accessor methods on dnsheader, DNSName, DNSResourceRecord, ComboAddress,
//    the ComboAddress set type (cas_t), QType, Netmask, NetmaskGroup,
//    DNSRecord and timeval
//  - the global functions newDN, newDRR, newCA, newCAFromRaw, newCAS,
//    newQType, newNetmask, newNMG, newDR, pdnslog and pdnsrandom
//  - the constants collected in d_pd (rcodes, `place`, `loglevels` and the
//    QType names), exposed as the global `pdns`
// postPrepareContext() is called before d_pd is written, so that it can still
// alter what ends up in `pdns`.
void BaseLua4::prepareContext() {
  d_lw = std::make_unique<LuaContext>();

  // lua features available
  Features features;
  getFeatures(features);
  d_lw->writeVariable("pdns_features", features);

  // dnsheader
  // the 16 bit fields (id and the four counters) go through ntohs() before
  // being returned; the flag and code fields are returned as stored
  d_lw->registerFunction<int(dnsheader::*)()>("getID", [](dnsheader& dh) { return ntohs(dh.id); });
  d_lw->registerFunction<bool(dnsheader::*)()>("getCD", [](dnsheader& dh) { return dh.cd; });
  d_lw->registerFunction<bool(dnsheader::*)()>("getTC", [](dnsheader& dh) { return dh.tc; });
  d_lw->registerFunction<bool(dnsheader::*)()>("getRA", [](dnsheader& dh) { return dh.ra; });
  d_lw->registerFunction<bool(dnsheader::*)()>("getAD", [](dnsheader& dh) { return dh.ad; });
  d_lw->registerFunction<bool(dnsheader::*)()>("getAA", [](dnsheader& dh) { return dh.aa; });
  d_lw->registerFunction<bool(dnsheader::*)()>("getRD", [](dnsheader& dh) { return dh.rd; });
  d_lw->registerFunction<int(dnsheader::*)()>("getRCODE", [](dnsheader& dh) { return dh.rcode; });
  d_lw->registerFunction<int(dnsheader::*)()>("getOPCODE", [](dnsheader& dh) { return dh.opcode; });
  d_lw->registerFunction<int(dnsheader::*)()>("getQDCOUNT", [](dnsheader& dh) { return ntohs(dh.qdcount); });
  d_lw->registerFunction<int(dnsheader::*)()>("getANCOUNT", [](dnsheader& dh) { return ntohs(dh.ancount); });
  d_lw->registerFunction<int(dnsheader::*)()>("getNSCOUNT", [](dnsheader& dh) { return ntohs(dh.nscount); });
  d_lw->registerFunction<int(dnsheader::*)()>("getARCOUNT", [](dnsheader& dh) { return ntohs(dh.arcount); });

  // DNSName
  d_lw->writeFunction("newDN", [](const std::string& dom){ return DNSName(dom); });
  d_lw->registerFunction("__lt", &DNSName::operator<);
  d_lw->registerFunction("canonCompare", &DNSName::canonCompare);
  d_lw->registerFunction<DNSName(DNSName::*)(const DNSName&)>("makeRelative", [](const DNSName& name, const DNSName& zone) { return name.makeRelative(zone); });
  d_lw->registerFunction<bool(DNSName::*)(const DNSName&)>("isPartOf", [](const DNSName& name, const DNSName& rhs) { return name.isPartOf(rhs); });
  d_lw->registerFunction("getRawLabels", &DNSName::getRawLabels);
  d_lw->registerFunction<unsigned int(DNSName::*)()>("countLabels", [](const DNSName& name) { return name.countLabels(); });
  // the wire length is available under two spellings, both call wirelength()
  d_lw->registerFunction<size_t(DNSName::*)()>("wireLength", [](const DNSName& name) { return name.wirelength(); });
  d_lw->registerFunction<size_t(DNSName::*)()>("wirelength", [](const DNSName& name) { return name.wirelength(); });
  // "equal" compares against a string by first constructing a DNSName from it
  d_lw->registerFunction<bool(DNSName::*)(const std::string&)>("equal", [](const DNSName& lhs, const std::string& rhs) { return lhs==DNSName(rhs); });
  d_lw->registerEqFunction(&DNSName::operator==);
  d_lw->registerToStringFunction<string(DNSName::*)()>([](const DNSName&dn ) { return dn.toString(); });
  d_lw->registerFunction<string(DNSName::*)()>("toString", [](const DNSName&dn ) { return dn.toString(); });
  d_lw->registerFunction<string(DNSName::*)()>("toStringNoDot", [](const DNSName&dn ) { return dn.toStringNoDot(); });
  d_lw->registerFunction<bool(DNSName::*)()>("chopOff", [](DNSName&dn ) { return dn.chopOff(); });

  // DNSResourceRecord
  // newDRR(qname, qtype, ttl, content [, domain_id [, auth]]): the two optional
  // arguments are only assigned when they were supplied
  d_lw->writeFunction("newDRR", [](const DNSName& qname, const string& qtype, const unsigned int ttl, const string& content, boost::optional<int> domain_id, boost::optional<int> auth){
    auto drr = DNSResourceRecord();
    drr.qname = qname;
    drr.qtype = qtype;
    drr.ttl = ttl;
    drr.setContent(content);
    if (domain_id)
      drr.domain_id = *domain_id;
    if (auth)
      drr.auth = *auth;
     return drr;
  });
  d_lw->registerEqFunction(&DNSResourceRecord::operator==);
  d_lw->registerFunction("__lt", &DNSResourceRecord::operator<);
  d_lw->registerToStringFunction<string(DNSResourceRecord::*)()>([](const DNSResourceRecord& rec) { return rec.getZoneRepresentation(); });
  d_lw->registerFunction<string(DNSResourceRecord::*)()>("toString", [](const DNSResourceRecord& rec) { return rec.getZoneRepresentation();} );
  // read-only accessors, each returns a copy of the corresponding field
  d_lw->registerFunction<DNSName(DNSResourceRecord::*)()>("qname", [](DNSResourceRecord& rec) { return rec.qname; });
  d_lw->registerFunction<DNSName(DNSResourceRecord::*)()>("wildcardName", [](DNSResourceRecord& rec) { return rec.wildcardname; });
  d_lw->registerFunction<string(DNSResourceRecord::*)()>("content", [](DNSResourceRecord& rec) { return rec.content; });
  d_lw->registerFunction<time_t(DNSResourceRecord::*)()>("lastModified", [](DNSResourceRecord& rec) { return rec.last_modified; });
  d_lw->registerFunction<uint32_t(DNSResourceRecord::*)()>("ttl", [](DNSResourceRecord& rec) { return rec.ttl; });
  d_lw->registerFunction<uint32_t(DNSResourceRecord::*)()>("signttl", [](DNSResourceRecord& rec) { return rec.signttl; });
  d_lw->registerFunction<int(DNSResourceRecord::*)()>("domainId", [](DNSResourceRecord& rec) { return rec.domain_id; });
  d_lw->registerFunction<uint16_t(DNSResourceRecord::*)()>("qtype", [](DNSResourceRecord& rec) { return rec.qtype.getCode(); });
  d_lw->registerFunction<uint16_t(DNSResourceRecord::*)()>("qclass", [](DNSResourceRecord& rec) { return rec.qclass; });
  d_lw->registerFunction<uint8_t(DNSResourceRecord::*)()>("scopeMask", [](DNSResourceRecord& rec) { return rec.scopeMask; });
  d_lw->registerFunction<bool(DNSResourceRecord::*)()>("auth", [](DNSResourceRecord& rec) { return rec.auth; });
  d_lw->registerFunction<bool(DNSResourceRecord::*)()>("disabled", [](DNSResourceRecord& rec) { return rec.disabled; });

  // ComboAddress
  // the address family is read from sin4.sin_family for both IPv4 and IPv6
  d_lw->registerFunction<bool(ComboAddress::*)()>("isIPv4", [](const ComboAddress& addr) { return addr.sin4.sin_family == AF_INET; });
  d_lw->registerFunction<bool(ComboAddress::*)()>("isIPv6", [](const ComboAddress& addr) { return addr.sin4.sin_family == AF_INET6; });
  d_lw->registerFunction<uint16_t(ComboAddress::*)()>("getPort", [](const ComboAddress& addr) { return ntohs(addr.sin4.sin_port); } );
  d_lw->registerFunction<bool(ComboAddress::*)()>("isMappedIPv4", [](const ComboAddress& addr) { return addr.isMappedIPv4(); });
  d_lw->registerFunction<ComboAddress(ComboAddress::*)()>("mapToIPv4", [](const ComboAddress& addr) { return addr.mapToIPv4(); });
  d_lw->registerFunction<void(ComboAddress::*)(unsigned int)>("truncate", [](ComboAddress& addr, unsigned int bits) { addr.truncate(bits); });
  d_lw->registerFunction<string(ComboAddress::*)()>("toString", [](const ComboAddress& addr) { return addr.toString(); });
  d_lw->registerToStringFunction<string(ComboAddress::*)()>([](const ComboAddress& addr) { return addr.toString(); });
  d_lw->registerFunction<string(ComboAddress::*)()>("toStringWithPort", [](const ComboAddress& addr) { return addr.toStringWithPort(); });
  d_lw->registerFunction<string(ComboAddress::*)()>("getRaw", [](const ComboAddress& addr) { return addr.toByteString(); });

  d_lw->writeFunction("newCA", [](const std::string& a) { return ComboAddress(a); });
  // newCAFromRaw(raw [, port]): a 4 byte string is taken as an IPv4 address, a
  // 16 byte string as an IPv6 address; any other length yields a
  // default-constructed ComboAddress. The port stays zero unless one is given.
  d_lw->writeFunction("newCAFromRaw", [](const std::string& raw, boost::optional<uint16_t> port) {
                                        if (raw.size() == 4) {
                                          struct sockaddr_in sin4;
                                          memset(&sin4, 0, sizeof(sin4));
                                          sin4.sin_family = AF_INET;
                                          memcpy(&sin4.sin_addr.s_addr, raw.c_str(), raw.size());
                                          if (port) {
                                            sin4.sin_port = htons(*port);
                                          }
                                          return ComboAddress(&sin4);
                                        }
                                        else if (raw.size() == 16) {
                                          struct sockaddr_in6 sin6;
                                          memset(&sin6, 0, sizeof(sin6));
                                          sin6.sin6_family = AF_INET6;
                                          memcpy(&sin6.sin6_addr.s6_addr, raw.c_str(), raw.size());
                                          if (port) {
                                            sin6.sin6_port = htons(*port);
                                          }
                                          return ComboAddress(&sin6);
                                        }
                                        return ComboAddress();
                                      });
  // set of addresses that hashes and compares with the addressOnly functors
  typedef std::unordered_set<ComboAddress,ComboAddress::addressOnlyHash,ComboAddress::addressOnlyEqual> cas_t;
  d_lw->registerFunction<bool(ComboAddress::*)(const ComboAddress&)>("equal", [](const ComboAddress& lhs, const ComboAddress& rhs) { return ComboAddress::addressOnlyEqual()(lhs, rhs); });

  // cas_t
  d_lw->writeFunction("newCAS", []{ return cas_t(); });
  // add() takes a single address string, a ComboAddress, or a list of
  // (index, string) pairs (presumably how a Lua array of strings arrives;
  // confirm against LuaContext). A std::exception raised while inserting is
  // logged and not passed on to the caller.
  d_lw->registerFunction<void(cas_t::*)(boost::variant<string,ComboAddress, vector<pair<unsigned int,string> > >)>("add",
    [](cas_t& cas, const boost::variant<string,ComboAddress,vector<pair<unsigned int,string> > >& in)
    {
      try {
      if(auto s = boost::get<string>(&in)) {
        cas.insert(ComboAddress(*s));
      }
      else if(auto v = boost::get<vector<pair<unsigned int, string> > >(&in)) {
        for(const auto& str : *v)
          cas.insert(ComboAddress(str.second));
      }
      else
        cas.insert(boost::get<ComboAddress>(in));
      }
      catch(std::exception& e) {
        SLOG(g_log <<Logger::Error<<e.what()<<endl,
             g_slog->withName("lua")->error(Logr::Error, e.what(), "Exception in newCAS", "exception", Logging::Loggable("std::exception")));
      }
    });
  d_lw->registerFunction<bool(cas_t::*)(const ComboAddress&)>("check",[](const cas_t& cas, const ComboAddress&ca) { return cas.count(ca)>0; });

  // QType
  d_lw->writeFunction("newQType", [](const string& s) { QType q; q = s; return q; });
  d_lw->registerFunction("getCode", &QType::getCode);
  d_lw->registerFunction("getName", &QType::toString);
  d_lw->registerEqFunction<bool(QType::*)(const QType&)>([](const QType& a, const QType& b){ return a == b;}); // operator overloading confuses LuaContext
  d_lw->registerToStringFunction(&QType::toString);

  // Netmask
  d_lw->writeFunction("newNetmask", [](const string& s) { return Netmask(s); });
  d_lw->registerFunction<ComboAddress(Netmask::*)()>("getNetwork", [](const Netmask& nm) { return nm.getNetwork(); } ); // const reference makes this necessary
  d_lw->registerFunction<ComboAddress(Netmask::*)()>("getMaskedNetwork", [](const Netmask& nm) { return nm.getMaskedNetwork(); } );
  // the family checks are registered under both the "Ipv4"/"Ipv6" and the
  // "IPv4"/"IPv6" spellings
  d_lw->registerFunction("isIpv4", &Netmask::isIPv4);
  d_lw->registerFunction("isIPv4", &Netmask::isIPv4);
  d_lw->registerFunction("isIpv6", &Netmask::isIPv6);
  d_lw->registerFunction("isIPv6", &Netmask::isIPv6);
  d_lw->registerFunction("getBits", &Netmask::getBits);
  d_lw->registerFunction("toString", &Netmask::toString);
  d_lw->registerFunction("empty", &Netmask::empty);
  d_lw->registerFunction("match", (bool (Netmask::*)(const string&) const)&Netmask::match);
  d_lw->registerEqFunction(&Netmask::operator==);
  d_lw->registerToStringFunction(&Netmask::toString);

  // NetmaskGroup
  // newNMG([masks]): starts from an empty group and adds every mask string of
  // the optional list
  d_lw->writeFunction("newNMG", [](boost::optional<vector<pair<unsigned int, std::string>>> masks) {
    auto nmg = NetmaskGroup();

    if (masks) {
      for(const auto& mask: *masks) {
        nmg.addMask(mask.second);
      }
    }

    return nmg;
  });
  // d_lw->writeFunction("newNMG", []() { return NetmaskGroup(); });
  d_lw->registerFunction<void(NetmaskGroup::*)(const std::string&mask)>("addMask", [](NetmaskGroup&nmg, const std::string& mask) { nmg.addMask(mask); });
  d_lw->registerFunction<void(NetmaskGroup::*)(const vector<pair<unsigned int, std::string>>&)>("addMasks", [](NetmaskGroup&nmg, const vector<pair<unsigned int, std::string>>& masks) { for(const auto& mask: masks) { nmg.addMask(mask.second); } });
  d_lw->registerFunction("match", (bool (NetmaskGroup::*)(const ComboAddress&) const)&NetmaskGroup::match);

  // DNSRecord
  // newDR(name, type, ttl, content, place): the type name is turned into its
  // code through QType and the content is built for class QClass::IN
  d_lw->writeFunction("newDR", [](const DNSName& name, const std::string& type, unsigned int ttl, const std::string& content, int place) { QType qtype; qtype = type; auto dr = DNSRecord(); dr.d_name = name; dr.d_type = qtype.getCode(); dr.d_ttl = ttl; dr.setContent(shared_ptr<DNSRecordContent>(DNSRecordContent::make(dr.d_type, QClass::IN, content))); dr.d_place = static_cast<DNSResourceRecord::Place>(place); return dr; });
  d_lw->registerMember("name", &DNSRecord::d_name);
  d_lw->registerMember("type", &DNSRecord::d_type);
  d_lw->registerMember("ttl", &DNSRecord::d_ttl);
  d_lw->registerMember("place", &DNSRecord::d_place);
  d_lw->registerFunction<string(DNSRecord::*)()>("getContent", [](const DNSRecord& dr) { return dr.getContent()->getZoneRepresentation(); });
  // getCA() yields an address only when the record holds A or AAAA content,
  // otherwise the returned optional is empty.
  // NOTE(review): 53 is handed to getCA(), presumably as the port - confirm.
  d_lw->registerFunction<boost::optional<ComboAddress>(DNSRecord::*)()>("getCA", [](const DNSRecord& dr) {
      boost::optional<ComboAddress> ret;

      if(auto arec = getRR<ARecordContent>(dr))
        ret=arec->getCA(53);
      else if(auto aaaarec = getRR<AAAARecordContent>(dr))
        ret=aaaarec->getCA(53);
      return ret;
    });
  // NOTE(review): the literal 1 below sits where newDR passes QClass::IN -
  // presumably the same value, confirm.
  d_lw->registerFunction<void (DNSRecord::*)(const std::string&)>("changeContent", [](DNSRecord& dr, const std::string& newContent) { dr.setContent(shared_ptr<DNSRecordContent>(DNSRecordContent::make(dr.d_type, 1, newContent))); });

  // pdnslog
  // Two variants selected at compile time. With RECURSOR defined the function
  // takes an extra optional map whose entries are attached to the structured
  // log entry; without it the message goes to g_log. In both variants the log
  // level falls back to Warning when none is given. The closing "});" after
  // the #endif terminates whichever lambda was compiled in.
#ifdef RECURSOR
  d_lw->writeFunction("pdnslog", [](const std::string& msg, boost::optional<int> loglevel, boost::optional<std::map<std::string, std::string>> values) {
    auto log = g_slog->withName("lua");
    if (values) {
      for (const auto& [key, value] : *values) {
        log = log->withValues(key, Logging::Loggable(value));
      }
    }
    log->info(static_cast<Logr::Priority>(loglevel.get_value_or(Logr::Warning)), msg);
#else
    d_lw->writeFunction("pdnslog", [](const std::string& msg, boost::optional<int> loglevel) {
      g_log << (Logger::Urgency)loglevel.get_value_or(Logger::Warning) << msg<<endl;
#endif
  });

  // pdnsrandom([maximum]): dns_random(maximum) when a maximum is given,
  // dns_random_uint32() otherwise
  d_lw->writeFunction("pdnsrandom", [](boost::optional<uint32_t> maximum) {
    return maximum ? dns_random(*maximum) : dns_random_uint32();
  });

  // certain constants

  vector<pair<string, int> > rcodes = {{"NOERROR",  RCode::NoError  },
                                       {"FORMERR",  RCode::FormErr  },
                                       {"SERVFAIL", RCode::ServFail },
                                       {"NXDOMAIN", RCode::NXDomain },
                                       {"NOTIMP",   RCode::NotImp   },
                                       {"REFUSED",  RCode::Refused  },
                                       {"YXDOMAIN", RCode::YXDomain },
                                       {"YXRRSET",  RCode::YXRRSet  },
                                       {"NXRRSET",  RCode::NXRRSet  },
                                       {"NOTAUTH",  RCode::NotAuth  },
                                       {"NOTZONE",  RCode::NotZone  },
                                       {"DROP",    -2               }}; // To give backport-incompatibility warning
  for(const auto& rcode : rcodes)
    d_pd.push_back({rcode.first, rcode.second});

  // nested table "place" with the four section numbers
  d_pd.push_back({"place", in_t{
    {"QUESTION", 0},
    {"ANSWER", 1},
    {"AUTHORITY", 2},
    {"ADDITIONAL", 3}
  }});

  // nested table "loglevels" mapping names to the syslog LOG_* values
  d_pd.push_back({"loglevels", in_t{
        {"Alert", LOG_ALERT},
        {"Critical", LOG_CRIT},
        {"Debug", LOG_DEBUG},
        {"Emergency", LOG_EMERG},
        {"Info", LOG_INFO},
        {"Notice", LOG_NOTICE},
        {"Warning", LOG_WARNING},
        {"Error", LOG_ERR}
          }});

  // every entry of QType::names is added to d_pd as well
  for(const auto& n : QType::names)
    d_pd.push_back({n.first, n.second});

  d_lw->registerMember("tv_sec", &timeval::tv_sec);
  d_lw->registerMember("tv_usec", &timeval::tv_usec);

  postPrepareContext();

  // so we can let postprepare do changes to this
  d_lw->writeVariable("pdns", d_pd);
}

// Runs the Lua code read from `stream` in the context held by d_lw. Once that
// has finished, postLoad() is invoked if and only if `doPostLoad` is true.
void BaseLua4::loadStream(std::istream& stream, bool doPostLoad)
{
  d_lw->executeCode(stream);

  if (!doPostLoad) {
    return;
  }
  postLoad();
}

// Destructor, defaulted here rather than in the class definition.
// NOTE(review): presumably kept out of line so that the LuaContext owned
// through d_lw is destroyed where its full definition is visible - confirm
// against lua-base4.hh.
BaseLua4::~BaseLua4() = default;