1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441
|
--[[
Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]] --
--[[
Web search context module for LLM-based spam detection
This module extracts domains from email URLs, queries a search API to fetch
relevant information about those domains, and formats the results as context
for LLM-based classification.
Main function:
- fetch_and_format(task, redis_params, opts, callback, debug_module): Fetch search context and format for LLM
Options (all optional with safe defaults):
enabled: boolean (default: false)
search_url: string (default: "https://leta.mullvad.net/search/__data.json")
search_engine: string (default: "brave") - search engine backend to query
max_domains: number (default: 3) - max domains to search
max_results_per_query: number (default: 3) - max results per domain
timeout: number (default: 5) - HTTP request timeout in seconds
cache_ttl: number (default: 3600) - cache TTL in seconds
cache_key_prefix: string (default: "llm_search")
as_system: boolean (default: true) - inject as system message vs user message
enable_expression: table - optional gating expression
disable_expression: table - optional negative gating expression
]]
-- Module debug tag, used as the default when callers don't pass debug_module
local N = 'llm_search_context'
-- Exported module table
local M = {}

local rspamd_http = require "rspamd_http"
local rspamd_logger = require "rspamd_logger"
local lua_util = require "lua_util"
local lua_cache = require "lua_cache"
local lua_mime = require "lua_mime"
local ucl = require "ucl"

-- Default option values; merged with caller-supplied opts via
-- lua_util.override_defaults() at the top of M.fetch_and_format()
local DEFAULTS = {
  enabled = false, -- module is opt-in; disabled unless explicitly enabled
  search_url = "https://leta.mullvad.net/search/__data.json",
  search_engine = "brave", -- Search engine to use (brave, google, etc.)
  max_domains = 3, -- max distinct domains extracted from the message
  max_results_per_query = 3, -- max search hits kept per domain
  timeout = 5, -- HTTP request timeout, seconds
  cache_ttl = 3600, -- 1 hour
  cache_key_prefix = "llm_search", -- Redis key prefix for cached results
  as_system = true, -- inject context as a system message rather than a user message
  enable_expression = nil, -- optional gating expression (evaluated by the caller)
  disable_expression = nil, -- optional negative gating expression
}
-- Extract unique domains from task URLs, prioritizing CTA (call-to-action) links.
-- Tries three sources in order of relevance until max_domains is reached:
-- 1) CTA URLs from the displayed HTML part, 2) content URLs, 3) any URLs.
-- Domains are deduplicated case-insensitively (DNS names are case-insensitive),
-- and a small denylist of useless hosts is skipped.
-- @param task rspamd task
-- @param max_domains maximum number of domains to return
-- @param debug_module optional debug tag overriding this module's own
-- @return array of host strings (original case preserved), possibly empty
local function extract_domains(task, max_domains, debug_module)
  local Np = debug_module or N
  local domains = {}
  local seen = {}
  -- Skip common domains that won't provide useful context
  local skip_domains = {
    ['localhost'] = true,
    ['127.0.0.1'] = true,
    ['example.com'] = true,
    ['example.org'] = true,
  }

  -- Append hosts from a list of URL objects until max_domains is reached.
  -- `label` only affects the debug log line ("CTA"/"content"/"general").
  local function add_hosts(urls, label)
    for _, url in ipairs(urls) do
      if #domains >= max_domains then
        break
      end
      local host = url:get_host()
      if host then
        -- Key `seen` on the lowercased host so that e.g. "Example.com" and
        -- "example.com" are not counted as two distinct domains
        local lhost = host:lower()
        if not skip_domains[lhost] and not seen[lhost] then
          seen[lhost] = true
          table.insert(domains, host)
          lua_util.debugm(Np, task, "added %s domain: %s", label, host)
        end
      end
    end
  end

  -- First, try to get CTA URLs from HTML (most relevant for spam detection)
  -- Uses button weight and HTML structure analysis from C code
  local cta_urls = {}
  local sel_part = lua_mime.get_displayed_text_part(task)
  if sel_part then
    cta_urls = sel_part:get_cta_urls()
  end
  lua_util.debugm(Np, task, "CTA analysis found %d URLs", #cta_urls)
  add_hosts(cta_urls, "CTA")

  -- If we don't have enough domains from CTA, get more from content URLs
  if #domains < max_domains then
    lua_util.debugm(Np, task, "need more domains (%d/%d), extracting from content URLs",
      #domains, max_domains)
    local urls = lua_util.extract_specific_urls({
      task = task,
      limit = max_domains * 3,
      esld_limit = max_domains,
      need_content = true, -- Content URLs (buttons, links in text)
      need_images = false,
    }) or {}
    lua_util.debugm(Np, task, "extracted %d content URLs", #urls)
    add_hosts(urls, "content")
  end

  -- Still need more? Get from any URLs
  if #domains < max_domains then
    lua_util.debugm(Np, task, "still need more domains (%d/%d), extracting from all URLs",
      #domains, max_domains)
    local urls = lua_util.extract_specific_urls({
      task = task,
      limit = max_domains * 3,
      esld_limit = max_domains,
    }) or {}
    lua_util.debugm(Np, task, "extracted %d all URLs", #urls)
    add_hosts(urls, "general")
  end

  return domains
end
-- Query search API for a single domain.
-- Performs an async HTTP GET against opts.search_url and parses the Leta
-- Mullvad "__data.json" response (a SvelteKit-style flat data array where
-- object fields hold 0-based indices into the same array instead of values).
-- Invokes callback(results_or_nil, domain, err_or_nil) exactly once, where
-- results has the shape { results = { {title=..., snippet=..., url=...}, ... } }.
local function query_search_api(task, domain, opts, callback, debug_module)
  local Np = debug_module or N
  -- Prepare search query for Leta Mullvad API
  local query_params = {
    q = domain,
    engine = opts.search_engine,
  }
  -- Build query string
  -- NOTE: pairs() order is unspecified, so the parameter order in the URL may
  -- vary between runs; harmless for the API, but full_url is not a stable key
  local query_string = ""
  for k, v in pairs(query_params) do
    if query_string ~= "" then
      query_string = query_string .. "&"
    end
    query_string = query_string .. k .. "=" .. lua_util.url_encode_string(v)
  end
  local full_url = opts.search_url .. "?" .. query_string
  -- HTTP completion handler: validates transport/status, then decodes the
  -- pointer-based flat structure into plain result records
  local function http_callback(err, code, body, _)
    if err then
      -- Transport-level failure (DNS, connect, timeout, ...)
      lua_util.debugm(Np, task, "search API error for domain '%s': %s", domain, err)
      callback(nil, domain, err)
      return
    end
    if code ~= 200 then
      rspamd_logger.infox(task, "search API returned code %s for domain '%s', url: %s, body: %s",
        code, domain, full_url, body and body:sub(1, 200) or 'nil')
      callback(nil, domain, string.format("HTTP %s", code))
      return
    end
    lua_util.debugm(Np, task, "search API success for domain '%s', url: %s", domain, full_url)
    -- Parse Leta Mullvad JSON response
    local parser = ucl.parser()
    local ok, parse_err = parser:parse_string(body)
    if not ok then
      rspamd_logger.errx(task, "%s: failed to parse search API response for %s: %s",
        Np, domain, parse_err)
      callback(nil, domain, parse_err)
      return
    end
    local data = parser:get_object()
    -- Extract search results from Leta Mullvad's nested structure
    -- Structure: data.nodes[3].data is a flat array with indices as pointers
    -- data[1] = metadata with pointers, data[5] = items array (Lua 1-indexed)
    -- Missing/malformed levels are tolerated: we fall through and return an
    -- empty (but non-nil) result set rather than erroring
    local search_results = { results = {} }
    if data and data.nodes and type(data.nodes) == 'table' and #data.nodes >= 3 then
      local search_node = data.nodes[3] -- Third node contains search data (Lua 1-indexed)
      if search_node and search_node.data and type(search_node.data) == 'table' then
        local flat_data = search_node.data
        local metadata = flat_data[1]
        lua_util.debugm(Np, task, "parsing domain '%s': flat_data has %d elements, metadata type: %s",
          domain, #flat_data, type(metadata))
        if metadata and metadata.items and type(metadata.items) == 'number' then
          -- metadata.items is a 0-indexed pointer, add 1 for Lua
          local items_idx = metadata.items + 1
          local items = flat_data[items_idx]
          if items and type(items) == 'table' then
            lua_util.debugm(Np, task, "found %d item indices for domain '%s', items_idx=%d",
              #items, domain, items_idx)
            local count = 0
            for _, result_idx in ipairs(items) do
              -- Stop once we've collected enough results for this domain
              if count >= opts.max_results_per_query then
                break
              end
              -- result_idx is 0-indexed, add 1 for Lua
              local result_template_idx = result_idx + 1
              local result_template = flat_data[result_template_idx]
              if result_template and type(result_template) == 'table' then
                -- Extract values using the template's pointers (also 0-indexed)
                local link = result_template.link and flat_data[result_template.link + 1]
                local snippet = result_template.snippet and flat_data[result_template.snippet + 1]
                local title = result_template.title and flat_data[result_template.title + 1]
                lua_util.debugm(Np, task, "result %d template: link_idx=%s, snippet_idx=%s, title_idx=%s",
                  count + 1, tostring(result_template.link), tostring(result_template.snippet),
                  tostring(result_template.title))
                -- Keep the record if at least one field dereferenced to a value
                if link or title or snippet then
                  table.insert(search_results.results, {
                    title = title or "",
                    snippet = snippet or "",
                    url = link or ""
                  })
                  count = count + 1
                  lua_util.debugm(Np, task, "extracted result %d: title='%s', snippet_len=%d",
                    count, title or "nil", snippet and #snippet or 0)
                end
              else
                lua_util.debugm(Np, task, "result_template at idx %d is not a table: %s",
                  result_template_idx, type(result_template))
              end
            end
          else
            lua_util.debugm(Np, task, "items is not a table for domain '%s', type: %s",
              domain, type(items))
          end
        else
          lua_util.debugm(Np, task, "no valid metadata.items for domain '%s'", domain)
        end
      end
    end
    lua_util.debugm(Np, task, "extracted %d search results for domain '%s'",
      #search_results.results, domain)
    callback(search_results, domain, nil)
  end
  -- Fire the async request; http_callback runs on completion
  rspamd_http.request({
    url = full_url,
    timeout = opts.timeout,
    callback = http_callback,
    task = task,
    log_obj = task,
  })
end
-- Format search results as context.
-- Renders the collected per-domain search results into a plain-text snippet
-- suitable for inclusion in an LLM prompt. Each domain gets a header line and
-- up to opts.max_results_per_query "title: snippet" bullet lines; snippets are
-- truncated to 200 characters. Returns nil when there is nothing to format.
local function format_search_results(all_results, opts)
  if not all_results or #all_results == 0 then
    return nil
  end

  local out = { "Web search context for domains in email:" }

  for _, entry in ipairs(all_results) do
    local dom = entry.domain
    local inner = entry.results
    local items = inner and inner.results

    if items and #items > 0 then
      out[#out + 1] = string.format("\nDomain: %s", dom)
      local limit = math.min(#items, opts.max_results_per_query)
      for idx = 1, limit do
        local item = items[idx]
        local heading = item.title or "No title"
        local descr = item.snippet or item.description or "No description"
        -- Truncate overly long snippets, leaving room for the ellipsis
        if #descr > 200 then
          descr = descr:sub(1, 197) .. "..."
        end
        out[#out + 1] = string.format(" - %s: %s", heading, descr)
      end
    else
      out[#out + 1] = string.format("\nDomain: %s - No search results found", dom)
    end
  end

  return table.concat(out, "\n")
end
-- Main function to fetch and format search context.
-- Extracts up to opts.max_domains domains from the task, fetches search
-- results for each (through the Redis-backed lua_cache when redis_params is
-- given), and invokes `callback` once all per-domain queries have completed.
-- @param task rspamd task
-- @param redis_params Redis connection params, or nil to disable caching
-- @param opts option table (see DEFAULTS); missing keys are filled with defaults
-- @param callback function(task, ok, context_or_nil) - ok=false means no context
-- @param debug_module optional debug tag overriding this module's own
function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
  local Np = debug_module or N
  -- Apply defaults
  opts = lua_util.override_defaults(DEFAULTS, opts or {})
  if not opts.enabled then
    lua_util.debugm(Np, task, "search context disabled")
    callback(task, false, nil)
    return
  end
  -- Extract domains from task
  local domains = extract_domains(task, opts.max_domains, Np)
  if #domains == 0 then
    lua_util.debugm(Np, task, "no domains to search")
    callback(task, false, nil)
    return
  end
  lua_util.debugm(Np, task, "final domain list (%d domains) for search: %s",
    #domains, table.concat(domains, ", "))
  -- Create cache context (optional: only when Redis is configured)
  local cache_ctx = nil
  if redis_params then
    cache_ctx = lua_cache.create_cache_context(redis_params, {
      cache_prefix = opts.cache_key_prefix,
      cache_ttl = opts.cache_ttl,
      cache_format = 'messagepack',
      cache_hash_len = 16,
      cache_use_hashing = true,
    }, Np)
  end
  -- Fan-out bookkeeping: one async query per domain, joined by a counter
  local pending_queries = #domains
  local all_results = {}
  -- Callback for each domain query complete; fires the final user callback
  -- once the last pending query resolves (must be called exactly once per domain)
  local function domain_complete(domain, results)
    pending_queries = pending_queries - 1
    if results then
      table.insert(all_results, {
        domain = domain,
        results = results
      })
    end
    if pending_queries == 0 then
      -- All queries complete
      if #all_results == 0 then
        lua_util.debugm(Np, task, "no search results obtained")
        callback(task, false, nil)
      else
        local context_snippet = format_search_results(all_results, opts)
        lua_util.debugm(Np, task, "search context formatted (%s bytes)",
          context_snippet and #context_snippet or 0)
        callback(task, true, context_snippet)
      end
    end
  end
  -- Process each domain
  for _, domain in ipairs(domains) do
    local cache_key = string.format("search:%s:%s", opts.search_engine, domain)
    if cache_ctx then
      -- Use lua_cache for caching.
      -- NOTE(review): this relies on lua_cache invoking the miss handler OR
      -- delivering (data/err) to the second handler so that domain_complete
      -- runs exactly once per domain — verify against lua_cache's contract
      lua_cache.cache_get(task, cache_key, cache_ctx, opts.timeout,
        function()
          -- Cache miss - query API
          query_search_api(task, domain, opts, function(api_results, d, api_err)
            if api_results then
              lua_cache.cache_set(task, cache_key, api_results, cache_ctx)
              domain_complete(d, api_results)
            else
              lua_util.debugm(Np, task, "search failed for domain %s: %s", d, api_err)
              domain_complete(d, nil)
            end
          end, Np)
        end,
        function(_, err, data)
          -- Cache hit or after miss callback
          if data and type(data) == 'table' then
            lua_util.debugm(Np, task, "cache hit for domain %s", domain)
            domain_complete(domain, data)
            -- If no data and no error, the miss callback was already invoked
          elseif err then
            lua_util.debugm(Np, task, "cache error for domain %s: %s", domain, err)
            domain_complete(domain, nil)
          end
        end)
    else
      -- No Redis, query directly
      query_search_api(task, domain, opts, function(api_results, d, api_err)
        if not api_results then
          lua_util.debugm(Np, task, "search failed for domain %s: %s", d, api_err)
        end
        domain_complete(d, api_results)
      end, Np)
    end
  end
end
return M
|