File: llm_search_context.lua

Package: rspamd 3.14.3-1
--[[
Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]] --

--[[
Web search context module for LLM-based spam detection

This module extracts domains from email URLs, queries a search API to fetch
relevant information about those domains, and formats the results as context
for LLM-based classification.

Main function:
  - fetch_and_format(task, redis_params, opts, callback, debug_module): Fetch search context and format for LLM

Options (all optional with safe defaults):
  enabled: boolean (default: false)
  search_url: string (default: "https://leta.mullvad.net/search/__data.json")
  search_engine: string (default: "brave") - search engine to use (brave, google, etc.)
  max_domains: number (default: 3) - max domains to search
  max_results_per_query: number (default: 3) - max results per domain
  timeout: number (default: 5) - HTTP request timeout in seconds
  cache_ttl: number (default: 3600) - cache TTL in seconds
  cache_key_prefix: string (default: "llm_search")
  as_system: boolean (default: true) - inject as system message vs user message
  enable_expression: table - optional gating expression
  disable_expression: table - optional negative gating expression
]]
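
--[[
Illustrative usage (a sketch only; the require path, redis_params and the option
values shown below are assumptions, not taken from this module):

  local llm_search_context = require "plugins/llm_search_context" -- hypothetical path
  local redis_params = nil -- or a parsed Redis config supplied by the calling plugin

  llm_search_context.fetch_and_format(task, redis_params, {
    enabled = true,
    max_domains = 2,
  }, function(task, ok, context)
    if ok and context then
      -- append `context` to the LLM prompt, as a system or user message
      -- depending on how the caller handles the `as_system` option
    end
  end, 'gpt')
]]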

local N = 'llm_search_context'

local M = {}

local rspamd_http = require "rspamd_http"
local rspamd_logger = require "rspamd_logger"
local lua_util = require "lua_util"
local lua_cache = require "lua_cache"
local lua_mime = require "lua_mime"
local ucl = require "ucl"

local DEFAULTS = {
  enabled = false,
  search_url = "https://leta.mullvad.net/search/__data.json",
  search_engine = "brave", -- Search engine to use (brave, google, etc.)
  max_domains = 3,
  max_results_per_query = 3,
  timeout = 5,
  cache_ttl = 3600, -- 1 hour
  cache_key_prefix = "llm_search",
  as_system = true,
  enable_expression = nil,
  disable_expression = nil,
}

-- Extract unique domains from task URLs, prioritizing CTA (call-to-action) links
local function extract_domains(task, max_domains, debug_module)
  local Np = debug_module or N
  local domains = {}
  local seen = {}

  -- Skip common domains that won't provide useful context
  local skip_domains = {
    ['localhost'] = true,
    ['127.0.0.1'] = true,
    ['example.com'] = true,
    ['example.org'] = true,
  }

  -- First, try to get CTA URLs from HTML (most relevant for spam detection)
  -- Uses button weight and HTML structure analysis from C code
  local cta_urls = {}
  local sel_part = lua_mime.get_displayed_text_part(task)
  if sel_part then
    cta_urls = sel_part:get_cta_urls()
  end
  lua_util.debugm(Np, task, "CTA analysis found %d URLs", #cta_urls)

  for _, url in ipairs(cta_urls) do
    if #domains >= max_domains then
      break
    end

    local host = url:get_host()
    if host and not skip_domains[host:lower()] and not seen[host] then
      seen[host] = true
      table.insert(domains, host)
      lua_util.debugm(Np, task, "added CTA domain: %s", host)
    end
  end

  -- If we don't have enough domains from CTA, get more from content URLs
  if #domains < max_domains then
    lua_util.debugm(Np, task, "need more domains (%d/%d), extracting from content URLs",
        #domains, max_domains)

    local urls = lua_util.extract_specific_urls({
      task = task,
      limit = max_domains * 3,
      esld_limit = max_domains,
      need_content = true, -- Content URLs (buttons, links in text)
      need_images = false,
    }) or {}

    lua_util.debugm(Np, task, "extracted %d content URLs", #urls)

    for _, url in ipairs(urls) do
      if #domains >= max_domains then
        break
      end

      local host = url:get_host()
      if host and not seen[host] and not skip_domains[host:lower()] then
        seen[host] = true
        table.insert(domains, host)
        lua_util.debugm(Np, task, "added content domain: %s", host)
      end
    end
  end

  -- Still need more? Get from any URLs
  if #domains < max_domains then
    lua_util.debugm(Np, task, "still need more domains (%d/%d), extracting from all URLs",
        #domains, max_domains)

    local urls = lua_util.extract_specific_urls({
      task = task,
      limit = max_domains * 3,
      esld_limit = max_domains,
    }) or {}

    lua_util.debugm(Np, task, "extracted %d all URLs", #urls)

    for _, url in ipairs(urls) do
      if #domains >= max_domains then
        break
      end

      local host = url:get_host()
      if host and not seen[host] and not skip_domains[host:lower()] then
        seen[host] = true
        table.insert(domains, host)
        lua_util.debugm(Np, task, "added general domain: %s", host)
      end
    end
  end

  return domains
end

-- Query search API for a single domain
local function query_search_api(task, domain, opts, callback, debug_module)
  local Np = debug_module or N

  -- Prepare search query for Leta Mullvad API
  local query_params = {
    q = domain,
    engine = opts.search_engine,
  }

  -- Build query string
  local query_string = ""
  for k, v in pairs(query_params) do
    if query_string ~= "" then
      query_string = query_string .. "&"
    end
    query_string = query_string .. k .. "=" .. lua_util.url_encode_string(v)
  end

  local full_url = opts.search_url .. "?" .. query_string
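  -- With the defaults this yields something like:
  --   https://leta.mullvad.net/search/__data.json?q=example.com&engine=brave
  -- (parameter order depends on pairs() iteration and may vary)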

  local function http_callback(err, code, body, _)
    if err then
      lua_util.debugm(Np, task, "search API error for domain '%s': %s", domain, err)
      callback(nil, domain, err)
      return
    end

    if code ~= 200 then
      rspamd_logger.infox(task, "search API returned code %s for domain '%s', url: %s, body: %s",
          code, domain, full_url, body and body:sub(1, 200) or 'nil')
      callback(nil, domain, string.format("HTTP %s", code))
      return
    end

    lua_util.debugm(Np, task, "search API success for domain '%s', url: %s", domain, full_url)

    -- Parse Leta Mullvad JSON response
    local parser = ucl.parser()
    local ok, parse_err = parser:parse_string(body)
    if not ok then
      rspamd_logger.errx(task, "%s: failed to parse search API response for %s: %s",
          Np, domain, parse_err)
      callback(nil, domain, parse_err)
      return
    end

    local data = parser:get_object()

    -- Extract search results from Leta Mullvad's nested structure
    -- Structure: data.nodes[3].data is a flat array with indices as pointers
    -- data[1] holds metadata; metadata.items is a 0-indexed pointer to the items array (1-indexed below)
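    -- Illustrative layout implied by the pointer arithmetic below (an assumption
    -- about the __data.json format, not a documented API):
    --   flat_data = {
    --     [1]  = { items = 4 },                          -- metadata; 0-indexed pointer -> flat_data[5]
    --     [5]  = { 7 },                                  -- result template indices (0-indexed)
    --     [8]  = { link = 8, title = 9, snippet = 10 },  -- template of 0-indexed value pointers
    --     [9]  = "https://example.com/",
    --     [10] = "Example title",
    --     [11] = "Example snippet text",
    --   }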
    local search_results = { results = {} }

    if data and data.nodes and type(data.nodes) == 'table' and #data.nodes >= 3 then
      local search_node = data.nodes[3]  -- Third node contains search data (Lua 1-indexed)

      if search_node and search_node.data and type(search_node.data) == 'table' then
        local flat_data = search_node.data
        local metadata = flat_data[1]

        lua_util.debugm(Np, task, "parsing domain '%s': flat_data has %d elements, metadata type: %s",
            domain, #flat_data, type(metadata))

        if metadata and metadata.items and type(metadata.items) == 'number' then
          -- metadata.items is a 0-indexed pointer, add 1 for Lua
          local items_idx = metadata.items + 1
          local items = flat_data[items_idx]

          if items and type(items) == 'table' then
            lua_util.debugm(Np, task, "found %d item indices for domain '%s', items_idx=%d",
                #items, domain, items_idx)

            local count = 0

            for _, result_idx in ipairs(items) do
              if count >= opts.max_results_per_query then
                break
              end

              -- result_idx is 0-indexed, add 1 for Lua
              local result_template_idx = result_idx + 1
              local result_template = flat_data[result_template_idx]

              if result_template and type(result_template) == 'table' then
                -- Extract values using the template's pointers (also 0-indexed)
                local link = result_template.link and flat_data[result_template.link + 1]
                local snippet = result_template.snippet and flat_data[result_template.snippet + 1]
                local title = result_template.title and flat_data[result_template.title + 1]

                lua_util.debugm(Np, task, "result %d template: link_idx=%s, snippet_idx=%s, title_idx=%s",
                    count + 1, tostring(result_template.link), tostring(result_template.snippet),
                    tostring(result_template.title))

                if link or title or snippet then
                  table.insert(search_results.results, {
                    title = title or "",
                    snippet = snippet or "",
                    url = link or ""
                  })
                  count = count + 1
                  lua_util.debugm(Np, task, "extracted result %d: title='%s', snippet_len=%d",
                      count, title or "nil", snippet and #snippet or 0)
                end
              else
                lua_util.debugm(Np, task, "result_template at idx %d is not a table: %s",
                    result_template_idx, type(result_template))
              end
            end
          else
            lua_util.debugm(Np, task, "items is not a table for domain '%s', type: %s",
                domain, type(items))
          end
        else
          lua_util.debugm(Np, task, "no valid metadata.items for domain '%s'", domain)
        end
      end
    end

    lua_util.debugm(Np, task, "extracted %d search results for domain '%s'",
        #search_results.results, domain)
    callback(search_results, domain, nil)
  end

  rspamd_http.request({
    url = full_url,
    timeout = opts.timeout,
    callback = http_callback,
    task = task,
    log_obj = task,
  })
end

-- Format search results as context
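-- The resulting snippet looks roughly like this (illustrative values only):
--   Web search context for domains in email:
--
--   Domain: example.com
--     - Example title: Example snippet text
--
--   Domain: example.org - No search results found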
local function format_search_results(all_results, opts)
  if not all_results or #all_results == 0 then
    return nil
  end

  local context_lines = {
    "Web search context for domains in email:"
  }

  for _, domain_result in ipairs(all_results) do
    local domain = domain_result.domain
    local results = domain_result.results

    if results and results.results and #results.results > 0 then
      table.insert(context_lines, string.format("\nDomain: %s", domain))

      for i, result in ipairs(results.results) do
        if i > opts.max_results_per_query then
          break
        end

        local title = result.title or "No title"
        local snippet = result.snippet or result.description or "No description"

        -- Truncate snippet if too long
        if #snippet > 200 then
          snippet = snippet:sub(1, 197) .. "..."
        end

        table.insert(context_lines, string.format("  - %s: %s", title, snippet))
      end
    else
      table.insert(context_lines, string.format("\nDomain: %s - No search results found", domain))
    end
  end

  return table.concat(context_lines, "\n")
end

-- Main function to fetch and format search context
function M.fetch_and_format(task, redis_params, opts, callback, debug_module)
  local Np = debug_module or N

  -- Apply defaults
  opts = lua_util.override_defaults(DEFAULTS, opts or {})

  if not opts.enabled then
    lua_util.debugm(Np, task, "search context disabled")
    callback(task, false, nil)
    return
  end

  -- Extract domains from task
  local domains = extract_domains(task, opts.max_domains, Np)

  if #domains == 0 then
    lua_util.debugm(Np, task, "no domains to search")
    callback(task, false, nil)
    return
  end

  lua_util.debugm(Np, task, "final domain list (%d domains) for search: %s",
      #domains, table.concat(domains, ", "))

  -- Create cache context
  local cache_ctx = nil
  if redis_params then
    cache_ctx = lua_cache.create_cache_context(redis_params, {
      cache_prefix = opts.cache_key_prefix,
      cache_ttl = opts.cache_ttl,
      cache_format = 'messagepack',
      cache_hash_len = 16,
      cache_use_hashing = true,
    }, Np)
  end

  local pending_queries = #domains
  local all_results = {}

  -- Callback for each domain query complete
  local function domain_complete(domain, results)
    pending_queries = pending_queries - 1

    if results then
      table.insert(all_results, {
        domain = domain,
        results = results
      })
    end

    if pending_queries == 0 then
      -- All queries complete
      if #all_results == 0 then
        lua_util.debugm(Np, task, "no search results obtained")
        callback(task, false, nil)
      else
        local context_snippet = format_search_results(all_results, opts)
        lua_util.debugm(Np, task, "search context formatted (%s bytes)",
            context_snippet and #context_snippet or 0)
        callback(task, true, context_snippet)
      end
    end
  end

  -- Process each domain
  for _, domain in ipairs(domains) do
    local cache_key = string.format("search:%s:%s", opts.search_engine, domain)
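    -- e.g. "search:brave:example.com"; lua_cache hashes this key because
    -- cache_use_hashing is enabled in the cache context above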

    if cache_ctx then
      -- Use lua_cache for caching
      lua_cache.cache_get(task, cache_key, cache_ctx, opts.timeout,
          function()
            -- Cache miss - query API
            query_search_api(task, domain, opts, function(api_results, d, api_err)
              if api_results then
                lua_cache.cache_set(task, cache_key, api_results, cache_ctx)
                domain_complete(d, api_results)
              else
                lua_util.debugm(Np, task, "search failed for domain %s: %s", d, api_err)
                domain_complete(d, nil)
              end
            end, Np)
          end,
          function(_, err, data)
            -- Cache hit or after miss callback
            if data and type(data) == 'table' then
              lua_util.debugm(Np, task, "cache hit for domain %s", domain)
              domain_complete(domain, data)
            elseif err then
              lua_util.debugm(Np, task, "cache error for domain %s: %s", domain, err)
              domain_complete(domain, nil)
            end
            -- If there is neither data nor an error, the miss callback above was already invoked
          end)
    else
      -- No Redis, query directly
      query_search_api(task, domain, opts, function(api_results, d, api_err)
        if not api_results then
          lua_util.debugm(Np, task, "search failed for domain %s: %s", d, api_err)
        end
        domain_complete(d, api_results)
      end, Np)
    end
  end
end

return M