File: llm_context.lua

package info (click to toggle)
rspamd 3.14.3-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 35,064 kB
  • sloc: ansic: 247,728; cpp: 107,741; javascript: 31,385; perl: 3,089; asm: 2,512; pascal: 1,625; python: 1,510; sh: 589; sql: 313; makefile: 195; xml: 74
file content (533 lines) | stat: -rw-r--r-- 16,479 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
--[[
Context management for LLM-based spam detection

Provides:
  - fetch(task, redis_params, opts, callback, debug_module): load context JSON from Redis and format prompt snippet
  - update_after_classification(task, redis_params, opts, result, sel_part, debug_module): update context after LLM result

Opts (all optional, safe defaults applied):
  enabled: boolean
  level: 'user' | 'domain' | 'esld' (scope for context key)
  key_prefix: string (prefix before identity in the Redis key)
  key_suffix: string (suffix after identity in the Redis key)
  max_messages: number (sliding window size)
  min_messages: number (minimum stored messages before context is injected into the prompt)
  message_ttl: seconds (maximum age of a stored message)
  ttl: seconds (Redis key TTL)
  top_senders: number (how many to keep in top_senders)
  summary_max_chars: number (truncate stored text)
  flagged_phrases: array of strings (case-insensitive match)
  last_labels_count: number (how many recent labels to keep)

debug_module: optional string, module name for debug logging (default: 'llm_context')
]]

-- Module table: the public functions below are attached to it and it is
-- returned at the end of the file.
local M = {}

local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"
local rspamd_util = require "rspamd_util"
local llm_common = require "llm_common"

-- Shared empty table used as a fallback when chaining lookups on values that
-- may be nil, e.g. (task:get_from('smtp') or EMPTY)[1]. It is never written to
-- in this file.
local EMPTY = {}

-- Default option values. M.fetch and M.update_after_classification merge the
-- caller-supplied opts over these with lua_util.override_defaults.
local DEFAULTS = {
  enabled = false, -- both public entry points return early unless this is true
  level = 'user', -- identity scope: 'user', 'domain' or 'esld'
  key_prefix = 'user', -- Redis key is built as <key_prefix>:<identity>:<key_suffix>
  key_suffix = 'mail_context',
  max_messages = 40, -- cap on the number of entries kept in recent_messages
  min_messages = 5, -- minimum messages in context before injecting into prompt
  message_ttl = 14 * 24 * 3600, -- 14 days in seconds; older messages are trimmed
  ttl = 30 * 24 * 3600, -- 30 days in seconds; expiry passed to SETEX
  top_senders = 5, -- how many senders recompute_top_senders keeps
  summary_max_chars = 512, -- limit applied to the stored message text
  -- Phrases looked for (case-insensitively) in the message words
  flagged_phrases = {
    'reset your password',
    'click here to verify',
    'confirm your account',
    'urgent invoice',
    'wire transfer',
  },
  last_labels_count = 10, -- cap on the number of entries in last_spam_labels
}

-- Coerce a configuration value to a number of seconds.
-- Numbers pass through untouched; anything else goes through tonumber(),
-- and values that cannot be converted yield 0.
local function to_seconds(v)
  local seconds = v
  if type(seconds) ~= 'number' then
    seconds = tonumber(seconds) or 0
  end
  return seconds
end

-- Return everything after the last '@' of `addr`.
-- Yields nil when `addr` is nil/false, has no '@', or nothing follows the '@'.
local function get_domain_from_addr(addr)
  if addr then
    -- The greedy '.*' makes the capture start after the LAST '@'
    return string.match(addr, '.*@(.+)')
  end
  return nil
end

-- Work out "our" side of the conversation so that incoming and outgoing mail
-- of the same mailbox/domain share one context.
--
-- A message counts as outgoing when the task has an authenticated user or its
-- IP reports itself as local; otherwise it is treated as incoming.
--   scope 'user'   -> outgoing: user, else reply sender, else SMTP from address
--                     incoming: principal recipient
--   scope 'domain' -> outgoing: domain part of user, else SMTP from domain
--                     incoming: domain part of the principal recipient
--   scope 'esld'   -> rspamd_util.get_tld() applied to the 'domain' result
-- Any other scope, or a lookup that finds nothing, yields nil.
local function get_our_identity(task, scope)
  local auth_user = task:get_user()
  local client_ip = task:get_ip()
  local outgoing = auth_user or (client_ip and client_ip:is_local())

  -- Field of the first SMTP envelope sender, nil when there is no sender
  local function smtp_from(field)
    local first = (task:get_from('smtp') or EMPTY)[1] or EMPTY
    return first[field]
  end

  -- Sender-side domain: taken from the authenticated user name when it
  -- carries one, otherwise from the SMTP envelope sender
  local function sender_domain()
    local dom
    if auth_user then
      dom = get_domain_from_addr(auth_user)
    end
    if not dom then
      dom = smtp_from('domain')
    end
    return dom
  end

  -- Recipient-side domain
  local function recipient_domain()
    local rcpt = task:get_principal_recipient()
    return get_domain_from_addr(rcpt)
  end

  if scope == 'user' then
    if not outgoing then
      -- Incoming: use recipient
      local rcpt = task:get_principal_recipient()
      return rcpt
    end
    -- Outgoing: use sender (authenticated user or from address)
    local sender = auth_user or task:get_reply_sender()
    if not sender then
      sender = smtp_from('addr')
    end
    return sender
  end

  if scope == 'domain' then
    if outgoing then
      return sender_domain()
    end
    return recipient_domain()
  end

  if scope == 'esld' then
    local dom
    if outgoing then
      dom = sender_domain()
    else
      dom = recipient_domain()
    end
    if dom then
      local esld = rspamd_util.get_tld(dom)
      return esld
    end
    return nil
  end

  return nil
end

-- Resolve the identity for `task` and derive the Redis key for its context.
-- Returns nil when no (non-empty) identity can be determined, otherwise a
-- table { scope = ..., identity = ..., key = '<prefix>:<identity>:<suffix>' }.
-- `debug_module` is the module name used for debug logging
-- (default: 'llm_context').
local function compute_identity(task, opts, debug_module)
  local N = debug_module or 'llm_context'
  local scope = opts.level or DEFAULTS.level

  local identity = get_our_identity(task, scope)
  if not identity or identity == '' then
    return nil
  end

  -- The direction is re-derived here purely for the debug message
  local direction = 'incoming'
  local auth_user = task:get_user()
  local client_ip = task:get_ip()
  if auth_user or (client_ip and client_ip:is_local()) then
    direction = 'outgoing'
  end
  lua_util.debugm(N, task, 'computed identity for %s (%s): %s',
    scope, direction, tostring(identity))

  local redis_key = string.format('%s:%s:%s',
    opts.key_prefix or DEFAULTS.key_prefix,
    identity,
    opts.key_suffix or DEFAULTS.key_suffix)

  return {
    scope = scope,
    identity = identity,
    key = redis_key,
  }
end

-- Parse a Redis reply with the UCL parser.
-- Returns nil when there is nothing usable to parse (nil/false, an empty
-- string, or a value that is neither a string nor userdata), `nil, err` when
-- the parser rejects the text, and otherwise whatever parser:get_object()
-- returns.
local function parse_json(data)
  if not data then return nil end

  -- Redis can return userdata nil or empty string
  local raw = data
  if type(raw) == 'userdata' then
    raw = tostring(raw)
  end

  local usable = type(raw) == 'string' and raw ~= ''
  if not usable then
    return nil
  end

  local parser = ucl.parser()
  local parsed_ok, parse_err = parser:parse_text(raw)
  if parsed_ok then
    return parser:get_object()
  end
  return nil, parse_err
end

-- Serialise `obj` with ucl.to_format using the 'json-compact' format and
-- return the result unchanged.
-- NOTE(review): the meaning of the third argument (true) is not visible in
-- this file - confirm against the UCL Lua API.
local function encode_json(obj)
  return ucl.to_format(obj, 'json-compact', true)
end

-- Current time as returned by os.time(). Used in this file for message
-- timestamps, the message age cut-off, `updated_at` and the logged expiry time.
local function now()
  return os.time()
end

-- Shorten `txt` to at most `limit` bytes without splitting a UTF-8 character.
--
-- Returns '' for a nil/false `txt` or a non-positive `limit`, and `txt` itself
-- when it already fits. Otherwise the cut is moved back (by at most 3 bytes,
-- the longest run of continuation bytes in a valid UTF-8 sequence) so that it
-- falls on a character boundary; the result may therefore be slightly shorter
-- than `limit`. If no boundary is found in that range the input is not valid
-- UTF-8 and a plain byte cut is used.
local function truncate_text(txt, limit)
  if not txt then return '' end
  if #txt <= limit then return txt end
  if limit <= 0 then return '' end

  local cut = limit
  for _ = 1, 4 do
    -- The byte right after the cut tells whether the cut is mid-character:
    -- 0x80..0xBF is a UTF-8 continuation byte, anything else starts a character
    local next_byte = txt:byte(cut + 1)
    if next_byte < 0x80 or next_byte >= 0xC0 then
      return txt:sub(1, cut)
    end
    if cut == 0 then break end
    cut = cut - 1
  end

  -- Not valid UTF-8 around the cut: fall back to the raw byte limit
  return txt:sub(1, limit)
end

-- Tell whether the array `flags` contains `flag_name`.
-- Anything that is not a table is treated as an empty flag list.
local function has_flag(flags, flag_name)
  local found = false
  if type(flags) == 'table' then
    for _, candidate in ipairs(flags) do
      if candidate == flag_name then
        found = true
        break
      end
    end
  end
  return found
end

-- Return up to `limit` (default 12) of the most frequent words of `text_part`.
--
-- Words come from text_part:get_words('full'); for each entry element [2] is
-- used as the word and element [4] as its flag list. A word is counted only
-- when it is flagged 'text', not flagged 'stop_word', and longer than two
-- bytes. The result is ordered by descending count, ties broken
-- alphabetically. A missing part or an empty word list yields an empty table.
local function extract_keywords(text_part, limit)
  if not text_part then return {} end
  local tokens = text_part:get_words('full')
  if not tokens or #tokens == 0 then return {} end

  -- Frequency of every acceptable normalized word
  local freq = {}
  for _, tok in ipairs(tokens) do
    local norm = tok[2] or '' -- normalized
    local tok_flags = tok[4] or {}
    -- Skip stop words, too short, or non-text
    local usable = not has_flag(tok_flags, 'stop_word')
      and #norm > 2
      and has_flag(tok_flags, 'text')
    if usable then
      freq[norm] = (freq[norm] or 0) + 1
    end
  end

  local ranked = {}
  for word, count in pairs(freq) do
    ranked[#ranked + 1] = { w = word, c = count }
  end
  table.sort(ranked, function(a, b)
    if a.c ~= b.c then return a.c > b.c end
    return a.w < b.w
  end)

  local top = {}
  local wanted = math.min(limit or 12, #ranked)
  for i = 1, wanted do
    top[i] = ranked[i].w
  end
  return top
end

-- Return `arr` when it is a table, otherwise a fresh empty table.
local function safe_array(arr)
  return type(arr) == 'table' and arr or {}
end

-- Build the compact record stored in the context for the current message.
--
-- The content comes from llm_common.build_llm_input(task, { max_tokens = 256 }).
-- Returns nil when that call does not return a table, otherwise a table with:
--   from     - sender reported by build_llm_input, else the SMTP from address
--   subject  - subject reported by build_llm_input, else ''
--   ts       - current time
--   keywords - up to 12 keywords extracted from `sel_part`
--   text     - message text truncated to opts.summary_max_chars (only when
--              the text is non-empty)
-- `sel_part` is used for keyword extraction only; the original code had two
-- identical branches keyed on it, which are merged here.
local function build_message_summary(task, sel_part, opts)
  local summary_max_tokens = 256
  local content_tbl = llm_common.build_llm_input(task, { max_tokens = summary_max_tokens })
  if type(content_tbl) ~= 'table' then
    return nil
  end

  local txt = content_tbl.text or ''
  local summary_max = opts.summary_max_chars or DEFAULTS.summary_max_chars
  local msg = {
    from = content_tbl.from or ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr'],
    subject = content_tbl.subject or '',
    ts = now(),
    keywords = extract_keywords(sel_part, 12),
  }
  if #txt > 0 then
    msg.text = truncate_text(txt, summary_max)
  end
  return msg
end

-- Produce the sliding window of stored messages: newest first, no older than
-- `min_ts`, at most `max_messages` entries.
--
-- recent_messages: array of message tables carrying a `ts` timestamp
-- max_messages:    maximum number of entries to keep (nil = no cap)
-- min_ts:          oldest accepted timestamp; nil keeps every message,
--                  otherwise messages without a usable `ts` are dropped
--
-- The list is loaded from Redis, so it is treated as untrusted: entries that
-- are not tables are skipped and timestamps are coerced with tonumber(), so a
-- malformed entry cannot raise an error and block every later update of the
-- same key. The input table is not modified; a new array is returned.
local function trim_messages(recent_messages, max_messages, min_ts)
  local cutoff = tonumber(min_ts)

  local kept = {}
  for _, m in ipairs(recent_messages or {}) do
    if type(m) == 'table' then
      local ts = tonumber(m.ts)
      if not cutoff or (ts and ts >= cutoff) then
        kept[#kept + 1] = m
      end
    end
  end

  -- Newest first; entries without a numeric timestamp sort as 0
  table.sort(kept, function(a, b)
    return (tonumber(a.ts) or 0) > (tonumber(b.ts) or 0)
  end)

  local cap = tonumber(max_messages)
  if cap then
    -- Clamp so that a negative cap empties the list instead of looping forever
    cap = math.max(0, math.floor(cap))
    for i = #kept, cap + 1, -1 do
      kept[i] = nil
    end
  end
  return kept
end

-- Return the `limit_n` senders with the highest counts from the
-- sender -> count map `sender_counts` (nil is treated as empty).
-- Ordered by descending count; equal counts are ordered by sender name.
local function recompute_top_senders(sender_counts, limit_n)
  local ranked = {}
  for sender, count in pairs(sender_counts or {}) do
    ranked[#ranked + 1] = { s = sender, c = count }
  end

  table.sort(ranked, function(a, b)
    if a.c ~= b.c then return a.c > b.c end
    return a.s < b.s
  end)

  local top = {}
  for i = 1, math.min(limit_n, #ranked) do
    top[#top + 1] = ranked[i].s
  end
  return top
end

-- Normalise a context object (typically decoded from Redis) so that every
-- collection field the rest of this file indexes or iterates is a table.
--
-- A non-table `ctx` is replaced by a new table. Each of recent_messages,
-- top_senders, flagged_phrases, last_spam_labels and sender_counts is kept
-- when it is already a table and replaced by {} otherwise. sender_counts used
-- to be only nil-checked, so a stray string/number there would later fail on
-- indexing; it now gets the same type check as the other fields.
-- Returns the (possibly new) context table; an existing table is updated in
-- place.
local function ensure_defaults(ctx)
  if type(ctx) ~= 'table' then ctx = {} end

  local table_fields = {
    'recent_messages',
    'top_senders',
    'flagged_phrases',
    'last_spam_labels',
    'sender_counts',
  }
  for _, field in ipairs(table_fields) do
    if type(ctx[field]) ~= 'table' then
      ctx[field] = {}
    end
  end
  return ctx
end

-- Case-insensitive plain-text (no pattern matching) substring test.
-- Returns false when either argument is nil/false.
local function contains_ci(haystack, needle)
  if haystack and needle then
    local hay = string.lower(haystack)
    local pin = string.lower(needle)
    -- 4th argument `true` disables pattern magic characters
    return string.find(hay, pin, 1, true) ~= nil
  end
  return false
end

-- Record in ctx.flagged_phrases every configured phrase that occurs in the
-- message text and is not stored yet.
--
-- The text is the space-joined result of text_part:get_words('norm'); phrases
-- come from opts.flagged_phrases (or the defaults). Both the occurrence test
-- and the "already stored" test are case-insensitive. Does nothing when there
-- is no text part or it has no words.
local function update_flagged_phrases(ctx, text_part, opts)
  local phrases = opts.flagged_phrases or DEFAULTS.flagged_phrases
  if not text_part then return end

  local tokens = text_part:get_words('norm')
  if not tokens or #tokens == 0 then return end
  local body = table.concat(tokens, ' ')

  -- True when `phrase` is already stored (ignoring case)
  local function already_recorded(phrase)
    local wanted = string.lower(phrase)
    for _, known in ipairs(ctx.flagged_phrases) do
      if string.lower(known) == wanted then
        return true
      end
    end
    return false
  end

  for _, phrase in ipairs(phrases) do
    if contains_ci(body, phrase) and not already_recorded(phrase) then
      table.insert(ctx.flagged_phrases, phrase)
    end
  end
end

-- Render the first `limit_n` entries of `recent_messages` as a bullet list,
-- one '- <from>: <subject>' line per message, joined with newlines.
-- The sender is taken from `from`, falling back to `sender`, then to ''.
-- Returns '' when there are no messages.
local function to_bullets_recent(recent_messages, limit_n)
  local shown = math.min(limit_n, #recent_messages)
  local out = {}
  for idx = 1, shown do
    local entry = recent_messages[idx]
    out[idx] = string.format('- %s: %s',
      entry.from or entry.sender or '',
      entry.subject or '')
  end
  return table.concat(out, '\n')
end

-- Join the array `arr` with ', '; a nil or empty array gives ''.
local function join_list(arr)
  if arr and #arr > 0 then
    return table.concat(arr, ', ')
  end
  return ''
end

-- Turn a context object into the text snippet that is added to the LLM prompt.
--
-- The snippet lists up to five recent messages, the top senders, and - when
-- present - flagged phrases and the most recent labels. The last line
-- classifies the current SMTP sender by its stored count: 'frequent' (>= 10),
-- 'known' (>= 3), 'occasional' (seen, but fewer than 3 times) or 'new' (not
-- seen, or no task given).
local function format_context_prompt(ctx, task)
  local bullets = to_bullets_recent(ctx.recent_messages or {}, 5)
  local top_senders = join_list(ctx.top_senders or {})
  local flagged = join_list(ctx.flagged_phrases or {})
  local spam_types = join_list(ctx.last_spam_labels or {})

  -- Check if current sender is known
  local seen_count
  if task then
    local from = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
    if from and ctx.sender_counts then
      seen_count = ctx.sender_counts[from]
    end
  end

  local sender_frequency = 'new'
  if seen_count then
    if seen_count >= 10 then
      sender_frequency = 'frequent'
    elseif seen_count >= 3 then
      sender_frequency = 'known'
    else
      sender_frequency = 'occasional'
    end
  end

  local parts = {
    'User recent correspondence summary:',
    bullets ~= '' and bullets or '- (no recent messages)',
    string.format('Top senders in mailbox: %s', top_senders),
  }
  if flagged ~= '' then
    parts[#parts + 1] = string.format('Recently flagged suspicious phrases: %s', flagged)
  end
  if spam_types ~= '' then
    parts[#parts + 1] = string.format('Last detected spam types: %s', spam_types)
  end
  parts[#parts + 1] = string.format('Current sender: %s', sender_frequency)

  return table.concat(parts, '\n')
end

-- Load the stored context for the task's identity from Redis and hand it to
-- `callback(err, ctx, prompt_snippet)`.
--
-- task:         the task being processed
-- redis_params: Redis parameters passed to lua_redis.redis_make_request
-- opts:         options table, merged over DEFAULTS
-- callback:     receives
--                 (nil, nil, nil)     when the feature is disabled
--                 (err, nil, nil)     when Redis is not configured, no identity
--                                     can be computed, the request cannot be
--                                     scheduled, or the GET fails
--                 (nil, ctx, nil)     when the context holds fewer than
--                                     opts.min_messages messages
--                 (nil, ctx, snippet) otherwise
-- debug_module: optional module name for debug logging (default 'llm_context')
--
-- Stored data that cannot be parsed is treated as an empty context; the parse
-- error is reported through debug logging instead of being dropped silently.
function M.fetch(task, redis_params, opts, callback, debug_module)
  local N = debug_module or 'llm_context'
  opts = lua_util.override_defaults(DEFAULTS, opts or {})
  if not opts.enabled then
    callback(nil, nil, nil)
    return
  end
  if not redis_params then
    callback('no redis', nil, nil)
    return
  end

  local ident = compute_identity(task, opts, N)
  if not ident then
    lua_util.debugm(N, task, 'no identity computed, skipping context')
    callback('no identity', nil, nil)
    return
  end

  lua_util.debugm(N, task, 'fetching context for %s: %s',
    tostring(ident.scope), tostring(ident.identity))

  local function on_get(err, data)
    if err then
      rspamd_logger.errx(task, 'llm_context: get failed: %s', tostring(err))
      callback(err, nil, nil)
      return
    end
    local ctx
    if data then
      lua_util.debugm(N, task, 'got context data from redis, parsing')
      local parsed, parse_err = parse_json(data)
      if parse_err then
        lua_util.debugm(N, task, 'cannot parse stored context, using empty: %s',
          tostring(parse_err))
      end
      ctx = ensure_defaults(parsed or {})
    else
      lua_util.debugm(N, task, 'no context data in redis, using empty')
      ctx = ensure_defaults({})
    end

    -- Check if context has enough messages for warm-up
    local min_msgs = opts.min_messages or DEFAULTS.min_messages
    local msg_count = #(ctx.recent_messages or {})
    if msg_count < min_msgs then
      lua_util.debugm(N, task, 'context has only %s messages (min: %s), not injecting into prompt',
        tostring(msg_count), tostring(min_msgs))
      callback(nil, ctx, nil) -- return ctx but no prompt snippet
      return
    end

    lua_util.debugm(N, task, 'context warm-up OK: %s messages, generating snippet',
      tostring(msg_count))
    local prompt_snippet = format_context_prompt(ctx, task)
    callback(nil, ctx, prompt_snippet)
  end

  local ok = lua_redis.redis_make_request(task, redis_params, ident.key, false, on_get, 'GET', { ident.key })
  if not ok then
    callback('request not scheduled', nil, nil)
  end
end

-- Update the stored context once the LLM has classified the message.
--
-- Reads the current context with GET, prepends a summary of this message,
-- bumps the sender counter, records matching flagged phrases, trims the
-- message window (by age and by count), recomputes the top senders, prepends
-- the new labels, and writes everything back with SETEX using opts.ttl.
--
-- task:         the task being processed
-- redis_params: Redis parameters passed to lua_redis.redis_make_request
-- opts:         options table, merged over DEFAULTS
-- result:       optional classification result; `categories` (array) and
--               `probability` are read from it. A probability above 0.5 is
--               labelled 'spam', otherwise 'ham'
-- sel_part:     text part used for keywords and flagged phrases (may be nil)
-- debug_module: optional module name for debug logging (default 'llm_context')
--
-- Returns nothing; failures are logged. Does nothing when the feature is
-- disabled, Redis is not configured or no identity can be computed.
--
-- NOTE(review): the GET and the SETEX are two separate requests, so concurrent
-- updates of the same key can overwrite each other.
-- NOTE(review): sender_counts is only ever incremented here - entries are not
-- dropped when their messages leave the window. Confirm this is intended.
function M.update_after_classification(task, redis_params, opts, result, sel_part, debug_module)
  local N = debug_module or 'llm_context'
  opts = lua_util.override_defaults(DEFAULTS, opts or {})
  if not opts.enabled then return end
  if not redis_params then return end

  local ident = compute_identity(task, opts, N)
  if not ident then return end

  local function on_get(err, data)
    if err then
      rspamd_logger.errx(task, 'llm_context: get for update failed: %s', tostring(err))
      return
    end
    lua_util.debugm(N, task, 'updating context for %s: %s',
      tostring(ident.scope), tostring(ident.identity))

    -- Unparseable stored data is replaced by a fresh context; leave a trace
    -- of that instead of discarding the parse error silently
    local parsed, parse_err = parse_json(data)
    if parse_err then
      lua_util.debugm(N, task, 'cannot parse stored context, starting from empty: %s',
        tostring(parse_err))
    end
    local ctx = ensure_defaults(parsed or {})

    local msg = build_message_summary(task, sel_part, opts)
    if msg then
      table.insert(ctx.recent_messages, 1, msg)
      local sender = msg.from or ''
      if sender ~= '' then
        ctx.sender_counts[sender] = (ctx.sender_counts[sender] or 0) + 1
      end
      update_flagged_phrases(ctx, sel_part, opts)
    end

    local min_ts = now() - to_seconds(opts.message_ttl)
    ctx.recent_messages = trim_messages(ctx.recent_messages, opts.max_messages, min_ts)
    ctx.top_senders = recompute_top_senders(ctx.sender_counts, opts.top_senders)

    local labels = {}
    if result then
      if result.categories and type(result.categories) == 'table' then
        for _, c in ipairs(result.categories) do table.insert(labels, tostring(c)) end
      end
      -- The probability may arrive as a numeric string; coerce it so that the
      -- comparison cannot raise, and skip the label when it is not a number
      local probability = tonumber(result.probability)
      if probability then
        if probability > 0.5 then
          table.insert(labels, 'spam')
        else
          table.insert(labels, 'ham')
        end
      end
    end
    -- Newest labels go to the front; the oldest ones fall off the end
    for _, l in ipairs(labels) do table.insert(ctx.last_spam_labels, 1, l) end
    while #ctx.last_spam_labels > opts.last_labels_count do table.remove(ctx.last_spam_labels) end

    ctx.updated_at = now()

    local payload = encode_json(ctx)
    local ttl = to_seconds(opts.ttl)
    local expire_at = now() + ttl

    -- Log what we're storing in context
    lua_util.debugm(N, task,
      'storing context for %s: %s messages, labels=%s, top_senders=%s, flagged=%s, payload_size=%s bytes, expiring at %s',
      tostring(ident.identity or '(none)'),
      tostring(#ctx.recent_messages),
      table.concat(ctx.last_spam_labels or {}, ','),
      table.concat(ctx.top_senders or {}, ','),
      table.concat(ctx.flagged_phrases or {}, ','),
      tostring(#payload),
      os.date('%Y-%m-%d %H:%M:%S', expire_at))

    if msg then
      lua_util.debugm(N, task,
        'added message: from=%s, subject=%s, keywords=%s',
        tostring(msg.from or '(none)'),
        tostring(msg.subject or '(none)'),
        table.concat(msg.keywords or {}, ','))
    end

    local function on_set(set_err)
      if set_err then
        rspamd_logger.errx(task, 'llm_context: set failed: %s', tostring(set_err))
      else
        lua_util.debugm(N, task, 'context saved to redis: key=%s, ttl=%s, expiring at %s',
          tostring(ident.key), tostring(ttl), os.date('%Y-%m-%d %H:%M:%S', expire_at))
      end
    end
    local ok = lua_redis.redis_make_request(task, redis_params, ident.key, true, on_set, 'SETEX',
      { ident.key, tostring(math.floor(ttl)), payload })
    if not ok then
      rspamd_logger.errx(task, 'llm_context: set request was not scheduled')
    end
  end

  local ok = lua_redis.redis_make_request(task, redis_params, ident.key, false, on_get, 'GET', { ident.key })
  if not ok then
    rspamd_logger.errx(task, 'llm_context: initial get request was not scheduled')
  end
end

return M