1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533
|
--[[
Context management for LLM-based spam detection
Provides:
- fetch(task, redis_params, opts, callback, debug_module): load context JSON from Redis and format prompt snippet
- update_after_classification(task, redis_params, opts, result, sel_part, debug_module): update context after LLM result
Opts (all optional, safe defaults applied):
enabled: boolean
level: 'user' | 'domain' | 'esld' (scope for context key)
key_prefix: string (prefix before scope)
key_suffix: string (suffix after identity)
max_messages: number (sliding window size)
min_messages: number (minimum stored messages before the context is injected into the prompt)
message_ttl: seconds
ttl: seconds (Redis key TTL)
top_senders: number (how many to keep in top_senders)
summary_max_chars: number (truncate stored text)
flagged_phrases: array of strings (case-insensitive match)
last_labels_count: number
debug_module (passed as a separate function argument, not inside opts): optional string, module name for debug logging (default: 'llm_context')
]]
local M = {}
local lua_redis = require "lua_redis"
local lua_util = require "lua_util"
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"
local rspamd_util = require "rspamd_util"
local llm_common = require "llm_common"
-- Shared empty table used as a nil-safe fallback when chaining optional lookups
-- such as ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr'].
local EMPTY = {}
-- Built-in option values. M.fetch and M.update_after_classification merge the
-- caller's opts over this table with lua_util.override_defaults.
local DEFAULTS = {
enabled = false, -- both entry points do nothing useful unless this is set
level = 'user', -- scope for the context key: 'user' | 'domain' | 'esld'
key_prefix = 'user', -- Redis key is '<key_prefix>:<identity>:<key_suffix>'
key_suffix = 'mail_context',
max_messages = 40, -- sliding window size for recent_messages
min_messages = 5, -- minimum messages in context before injecting into prompt
message_ttl = 14 * 24 * 3600, -- seconds (14 days); older messages are dropped on update
ttl = 30 * 24 * 3600, -- seconds (30 days); expiry applied to the Redis key
top_senders = 5, -- how many senders to keep in top_senders
summary_max_chars = 512, -- stored message text is truncated to this length
flagged_phrases = { -- phrases looked for (case-insensitively) in the message words
'reset your password',
'click here to verify',
'confirm your account',
'urgent invoice',
'wire transfer',
},
last_labels_count = 10, -- maximum number of entries kept in last_spam_labels
}
-- Coerce a duration setting to a number.
-- Numbers are returned untouched; any other value goes through tonumber(),
-- and 0 is returned when that conversion fails.
local function to_seconds(v)
  if type(v) ~= 'number' then
    local converted = tonumber(v)
    if converted == nil then
      return 0
    end
    return converted
  end
  return v
end
-- Extract the domain of an e-mail address.
-- Returns whatever the pattern '.*@(.+)' captures (the text after the '@'),
-- or nil when addr is nil/false or contains no '@' followed by at least one
-- character.
local function get_domain_from_addr(addr)
  if addr then
    local domain = string.match(addr, '.*@(.+)')
    return domain
  end
  return nil
end
-- Work out which local user/domain a message belongs to, so that incoming
-- and outgoing mail of the same party share one context.
--
-- A message counts as outgoing when the task has an authenticated user or its
-- client IP reports is_local(); otherwise it is treated as incoming.
--   scope 'user'   : outgoing -> authenticated user, else reply sender, else
--                    SMTP from address; incoming -> principal recipient
--   scope 'domain' : outgoing -> domain of the authenticated user, else SMTP
--                    from domain; incoming -> domain of the principal recipient
--   scope 'esld'   : same domain as above, passed through rspamd_util.get_tld
-- Any other scope yields nil. Returns the identity or nil.
local function get_our_identity(task, scope)
  local auth_user = task:get_user()
  local client_ip = task:get_ip()
  local outgoing = auth_user or (client_ip and client_ip:is_local())

  if scope == 'user' then
    local who
    if outgoing then
      who = auth_user or task:get_reply_sender()
      if not who then
        who = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
      end
    else
      who = task:get_principal_recipient()
    end
    return who
  end

  if scope ~= 'domain' and scope ~= 'esld' then
    return nil
  end

  -- 'domain' and 'esld' resolve the same raw domain first
  local dom
  if outgoing then
    if auth_user then
      dom = get_domain_from_addr(auth_user)
    end
    if not dom then
      dom = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['domain']
    end
  else
    local rcpt = task:get_principal_recipient()
    dom = get_domain_from_addr(rcpt)
  end

  if scope == 'domain' then
    return dom
  end

  local esld
  if dom then
    esld = rspamd_util.get_tld(dom)
  end
  return esld
end
-- Resolve the context identity for a task and derive its Redis key.
-- Returns nil when no (or an empty) identity can be determined, otherwise a
-- table { scope, identity, key } where key is
-- '<key_prefix>:<identity>:<key_suffix>'.
-- debug_module is the module name used for debug logging
-- (default 'llm_context').
local function compute_identity(task, opts, debug_module)
  local log_module = debug_module or 'llm_context'
  local scope = opts.level or DEFAULTS.level
  local identity = get_our_identity(task, scope)
  if identity == '' or not identity then
    return nil
  end

  -- Direction is recomputed here purely for the debug line
  local direction = 'incoming'
  local auth_user = task:get_user()
  local client_ip = task:get_ip()
  if auth_user or (client_ip and client_ip:is_local()) then
    direction = 'outgoing'
  end
  lua_util.debugm(log_module, task, 'computed identity for %s (%s): %s',
      scope, direction, tostring(identity))

  local prefix = opts.key_prefix or DEFAULTS.key_prefix
  local suffix = opts.key_suffix or DEFAULTS.key_suffix
  local result = {
    scope = scope,
    identity = identity,
    key = string.format('%s:%s:%s', prefix, identity, suffix),
  }
  return result
end
-- Decode a Redis reply into a Lua value using the UCL parser.
-- Returns nil for missing, empty or non-string input, and nil plus the parser
-- error when the text cannot be parsed; otherwise the parsed object.
local function parse_json(data)
  if not data then
    return nil
  end
  local raw = data
  -- Redis can return userdata nil or empty string
  if type(raw) == 'userdata' then
    raw = tostring(raw)
  end
  if raw == '' or type(raw) ~= 'string' then
    return nil
  end
  local parser = ucl.parser()
  local ok, err = parser:parse_text(raw)
  if ok then
    return parser:get_object()
  end
  return nil, err
end
-- Serialise a Lua value with ucl.to_format using the 'json-compact' emitter.
-- NOTE(review): the trailing `true` is passed straight to ucl.to_format; its
-- exact meaning is not visible in this file — confirm against the UCL binding.
local function encode_json(obj)
return ucl.to_format(obj, 'json-compact', true)
end
-- Current timestamp as reported by os.time(); used for message `ts`,
-- `updated_at` and expiry calculations in this module.
local function now()
return os.time()
end
-- Limit txt to at most `limit` bytes.
-- Returns '' for nil/false input and txt itself when it already fits.
-- For plain Lua strings the cut is moved back (by at most 3 bytes) so that it
-- never lands in the middle of a multi-byte UTF-8 character; a raw byte cut
-- could leave an invalid trailing sequence in the stored text.
-- Non-string values that support `#` and `:sub` keep the plain byte cut.
local function truncate_text(txt, limit)
  if not txt then return '' end
  if #txt <= limit then return txt end
  if type(txt) ~= 'string' then
    return txt:sub(1, limit)
  end
  local cut = limit
  local backed = 0
  -- txt is longer than limit here, so byte cut + 1 always exists
  while cut > 0 and backed < 3 do
    local nxt = txt:byte(cut + 1)
    -- 0x80..0xBF are UTF-8 continuation bytes: the cut would split a character
    if nxt < 0x80 or nxt >= 0xC0 then
      break
    end
    cut = cut - 1
    backed = backed + 1
  end
  return txt:sub(1, cut)
end
-- Report whether flag_name occurs in the array `flags`.
-- Anything that is not a table counts as "no flags" and yields false.
local function has_flag(flags, flag_name)
  local found = false
  if type(flags) == 'table' then
    for _, candidate in ipairs(flags) do
      if candidate == flag_name then
        found = true
        break
      end
    end
  end
  return found
end
-- Pick the most frequent meaningful words of a text part.
-- Words come from text_part:get_words('full'); for each entry, element 2 is
-- treated as the normalised word and element 4 as its flag list. A word is
-- counted when it is not flagged 'stop_word', is longer than 2 bytes and is
-- flagged 'text'. Results are ordered by descending count, ties broken
-- alphabetically, and capped at `limit` (default 12).
-- Returns an array of words (empty when there is no part or no words).
local function extract_keywords(text_part, limit)
  if not text_part then return {} end
  local words = text_part:get_words('full')
  if not words or #words == 0 then return {} end

  local freq = {}
  for _, entry in ipairs(words) do
    local token = entry[2] or ''
    local flags = entry[4] or {}
    if not has_flag(flags, 'stop_word') then
      if #token > 2 and has_flag(flags, 'text') then
        freq[token] = (freq[token] or 0) + 1
      end
    end
  end

  local ranked = {}
  for token, hits in pairs(freq) do
    ranked[#ranked + 1] = { w = token, c = hits }
  end
  table.sort(ranked, function(x, y)
    if x.c ~= y.c then
      return x.c > y.c
    end
    return x.w < y.w
  end)

  local top = {}
  local take = math.min(limit or 12, #ranked)
  for i = 1, take do
    top[i] = ranked[i].w
  end
  return top
end
-- Return arr when it is a table, otherwise a fresh empty table.
local function safe_array(arr)
  if type(arr) == 'table' then
    return arr
  end
  return {}
end
-- Build the compact record stored in the context for one message.
-- Uses llm_common.build_llm_input (max_tokens = 256) for sender, subject and
-- text; returns nil when that call does not yield a table.
-- The record holds: from (falling back to the SMTP from address), subject,
-- ts (current time), keywords (up to 12, from sel_part) and, when the text is
-- non-empty, text truncated to opts.summary_max_chars (or the default).
local function build_message_summary(task, sel_part, opts)
  local input = llm_common.build_llm_input(task, { max_tokens = 256 })
  if type(input) ~= 'table' then
    return nil
  end

  local body = input.text or ''
  local char_limit = opts.summary_max_chars or DEFAULTS.summary_max_chars

  local sender = input.from
  if not sender then
    sender = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
  end

  local summary = {
    from = sender,
    subject = input.subject or '',
    ts = now(),
    keywords = extract_keywords(sel_part, 12),
  }
  if #body > 0 then
    summary.text = truncate_text(body, char_limit)
  end
  return summary
end
-- Produce the sliding window of recent messages.
-- Keeps entries whose ts is at least min_ts (every entry when min_ts is nil),
-- orders them newest first (missing ts sorts as 0) and drops entries from the
-- tail until at most max_messages remain. Returns a new array; the input
-- array is not modified.
local function trim_messages(recent_messages, max_messages, min_ts)
  local kept = {}
  for _, entry in ipairs(recent_messages) do
    local fresh = true
    if min_ts then
      fresh = entry.ts and entry.ts >= min_ts
    end
    if fresh then
      kept[#kept + 1] = entry
    end
  end

  table.sort(kept, function(lhs, rhs)
    return (lhs.ts or 0) > (rhs.ts or 0)
  end)

  while #kept > max_messages do
    table.remove(kept)
  end
  return kept
end
-- Rank senders by how often they were seen.
-- sender_counts maps sender -> count (nil is treated as empty). Returns an
-- array of at most limit_n senders, highest count first, ties broken by
-- ascending sender name.
local function recompute_top_senders(sender_counts, limit_n)
  local ranked = {}
  for sender, hits in pairs(sender_counts or {}) do
    ranked[#ranked + 1] = { s = sender, c = hits }
  end

  table.sort(ranked, function(x, y)
    if x.c ~= y.c then
      return x.c > y.c
    end
    return x.s < y.s
  end)

  local top = {}
  local take = math.min(limit_n, #ranked)
  for i = 1, take do
    top[i] = ranked[i].s
  end
  return top
end
-- Normalise a decoded context so every field the module relies on is a table.
-- A non-table ctx is replaced by a new table. recent_messages, top_senders,
-- flagged_phrases and last_spam_labels are passed through safe_array, and
-- sender_counts (a sender -> count map) is reset to {} unless it is a table,
-- so a corrupt stored value cannot break later indexing.
-- Returns the (possibly new) ctx table; an existing table is modified in place.
local function ensure_defaults(ctx)
  if type(ctx) ~= 'table' then ctx = {} end
  ctx.recent_messages = safe_array(ctx.recent_messages)
  ctx.top_senders = safe_array(ctx.top_senders)
  ctx.flagged_phrases = safe_array(ctx.flagged_phrases)
  ctx.last_spam_labels = safe_array(ctx.last_spam_labels)
  -- Same type guard as the array fields: anything but a table is discarded
  if type(ctx.sender_counts) ~= 'table' then
    ctx.sender_counts = {}
  end
  return ctx
end
-- Case-insensitive plain-text substring test.
-- Both arguments are lower-cased and searched with pattern matching disabled.
-- Returns false when either argument is nil/false.
local function contains_ci(haystack, needle)
  if haystack and needle then
    local pos = string.find(string.lower(haystack), string.lower(needle), 1, true)
    return pos ~= nil
  end
  return false
end
-- Record which configured suspicious phrases occur in the message.
-- The 'norm' words of text_part are joined with single spaces and each phrase
-- from opts.flagged_phrases (or the defaults) is searched in that string,
-- ignoring case. Matching phrases are appended to ctx.flagged_phrases unless
-- an entry equal to them (ignoring case) is already present.
-- Does nothing when text_part is nil or has no words. Returns nothing.
local function update_flagged_phrases(ctx, text_part, opts)
  local phrases = opts.flagged_phrases or DEFAULTS.flagged_phrases
  if not text_part then return end
  local words = text_part:get_words('norm')
  if not words or #words == 0 then return end

  local joined = table.concat(words, ' ')
  for _, phrase in ipairs(phrases) do
    if contains_ci(joined, phrase) then
      local wanted = string.lower(phrase)
      local known = false
      for _, existing in ipairs(ctx.flagged_phrases) do
        if string.lower(existing) == wanted then
          known = true
          break
        end
      end
      if not known then
        table.insert(ctx.flagged_phrases, phrase)
      end
    end
  end
end
-- Render the first limit_n recent messages as '- <from>: <subject>' lines
-- joined by newlines ('' when there are none).
-- `from` falls back to the `sender` field and then to ''.
-- The values originate from message headers and end up inside an LLM prompt,
-- so runs of control characters (including CR/LF) are collapsed to a single
-- space: one stored message can then never span or forge additional lines.
-- Entries that are not tables are skipped instead of raising.
local function to_bullets_recent(recent_messages, limit_n)
  local function one_line(v)
    -- extra parentheses drop gsub's second return value (the match count)
    return (tostring(v):gsub('%c+', ' '))
  end

  local lines = {}
  local n = math.min(limit_n, #recent_messages)
  for i = 1, n do
    local m = recent_messages[i]
    if type(m) == 'table' then
      local from = m.from or m.sender or ''
      local subj = m.subject or ''
      lines[#lines + 1] = string.format('- %s: %s', one_line(from), one_line(subj))
    end
  end
  return table.concat(lines, '\n')
end
-- Join array elements with ', '; nil or an empty array gives ''.
local function join_list(arr)
  if arr and #arr > 0 then
    return table.concat(arr, ', ')
  end
  return ''
end
-- Turn a context table into the text snippet injected into the LLM prompt.
-- The snippet lists up to 5 recent messages, the top senders, any flagged
-- phrases and recent spam labels (both only when non-empty) and how familiar
-- the current sender is. Familiarity is looked up in ctx.sender_counts by the
-- task's SMTP from address: >= 10 'frequent', >= 3 'known', otherwise
-- 'occasional'; 'new' when there is no task, no address or no count.
-- Returns the snippet as a newline-joined string.
local function format_context_prompt(ctx, task)
  local recent = to_bullets_recent(ctx.recent_messages or {}, 5)
  local senders = join_list(ctx.top_senders or {})
  local phrases = join_list(ctx.flagged_phrases or {})
  local labels = join_list(ctx.last_spam_labels or {})

  -- Check if current sender is known
  local familiarity = 'new'
  if task then
    local addr = ((task:get_from('smtp') or EMPTY)[1] or EMPTY)['addr']
    local seen = addr and ctx.sender_counts and ctx.sender_counts[addr]
    if seen then
      if seen >= 10 then
        familiarity = 'frequent'
      elseif seen >= 3 then
        familiarity = 'known'
      else
        familiarity = 'occasional'
      end
    end
  end

  local out = { 'User recent correspondence summary:' }
  if recent == '' then
    out[#out + 1] = '- (no recent messages)'
  else
    out[#out + 1] = recent
  end
  out[#out + 1] = string.format('Top senders in mailbox: %s', senders)
  if phrases ~= '' then
    out[#out + 1] = string.format('Recently flagged suspicious phrases: %s', phrases)
  end
  if labels ~= '' then
    out[#out + 1] = string.format('Last detected spam types: %s', labels)
  end
  out[#out + 1] = string.format('Current sender: %s', familiarity)
  return table.concat(out, '\n')
end
-- Load the stored context for the task's identity and build a prompt snippet.
--
-- callback(err, ctx, prompt_snippet) is invoked exactly once per call path:
--   * (nil, nil, nil)          context handling disabled
--   * ('no redis' | 'no identity' | 'request not scheduled', nil, nil)
--   * (err, nil, nil)          Redis GET failed
--   * (nil, ctx, nil)          fewer than min_messages stored (warm-up)
--   * (nil, ctx, snippet)      context is warm, snippet is ready to inject
-- opts are merged over DEFAULTS; debug_module names the debug log module
-- (default 'llm_context').
function M.fetch(task, redis_params, opts, callback, debug_module)
  local log_module = debug_module or 'llm_context'
  opts = lua_util.override_defaults(DEFAULTS, opts or {})

  if not opts.enabled then
    callback(nil, nil, nil)
    return
  end
  if not redis_params then
    callback('no redis', nil, nil)
    return
  end

  local target = compute_identity(task, opts, log_module)
  if not target then
    lua_util.debugm(log_module, task, 'no identity computed, skipping context')
    callback('no identity', nil, nil)
    return
  end

  lua_util.debugm(log_module, task, 'fetching context for %s: %s',
      tostring(target.scope), tostring(target.identity))

  local function handle_reply(err, data)
    if err then
      rspamd_logger.errx(task, 'llm_context: get failed: %s', tostring(err))
      callback(err, nil, nil)
      return
    end

    local parsed
    if data then
      lua_util.debugm(log_module, task, 'got context data from redis, parsing')
      parsed = parse_json(data)
    else
      lua_util.debugm(log_module, task, 'no context data in redis, using empty')
    end
    local ctx = ensure_defaults(parsed or {})

    -- Check if context has enough messages for warm-up
    local required = opts.min_messages or DEFAULTS.min_messages
    local stored = #(ctx.recent_messages or {})
    if stored >= required then
      lua_util.debugm(log_module, task, 'context warm-up OK: %s messages, generating snippet',
          tostring(stored))
      local snippet = format_context_prompt(ctx, task)
      callback(nil, ctx, snippet)
    else
      lua_util.debugm(log_module, task, 'context has only %s messages (min: %s), not injecting into prompt',
          tostring(stored), tostring(required))
      callback(nil, ctx, nil) -- return ctx but no prompt snippet
    end
  end

  local scheduled = lua_redis.redis_make_request(task, redis_params, target.key, false,
      handle_reply, 'GET', { target.key })
  if not scheduled then
    callback('request not scheduled', nil, nil)
  end
end
-- Fold the message that was just classified into the stored context.
--
-- Reads the current context with GET, prepends a summary of this message,
-- bumps the sender counter, refreshes flagged phrases, top senders and the
-- recent label list, then writes the whole document back with SETEX.
-- Returns nothing; failures are only logged.
--
--   task          task being processed
--   redis_params  passed to lua_redis.redis_make_request
--   opts          optional overrides merged over DEFAULTS
--   result        optional classification result; `categories` (array) and
--                 `probability` are read when present
--   sel_part      text part used for keyword and phrase extraction (may be nil)
--   debug_module  module name for debug logging (default 'llm_context')
--
-- NOTE(review): the read and the write are two separate Redis requests, and
-- sender_counts is only ever incremented here (entries are not removed when
-- messages leave recent_messages).
function M.update_after_classification(task, redis_params, opts, result, sel_part, debug_module)
  local N = debug_module or 'llm_context'
  opts = lua_util.override_defaults(DEFAULTS, opts or {})
  if not opts.enabled then return end
  if not redis_params then return end
  local ident = compute_identity(task, opts, N)
  if not ident then return end

  local function on_get(err, data)
    if err then
      rspamd_logger.errx(task, 'llm_context: get for update failed: %s', tostring(err))
      return
    end
    lua_util.debugm(N, task, 'updating context for %s: %s',
        tostring(ident.scope), tostring(ident.identity))

    local ctx = ensure_defaults(select(1, parse_json(data)) or {})
    local msg = build_message_summary(task, sel_part, opts)
    if msg then
      table.insert(ctx.recent_messages, 1, msg)
      local sender = msg.from or ''
      if sender ~= '' then
        ctx.sender_counts[sender] = (ctx.sender_counts[sender] or 0) + 1
      end
      update_flagged_phrases(ctx, sel_part, opts)
    end

    local min_ts = now() - to_seconds(opts.message_ttl)
    ctx.recent_messages = trim_messages(ctx.recent_messages, opts.max_messages, min_ts)
    ctx.top_senders = recompute_top_senders(ctx.sender_counts, opts.top_senders)

    local labels = {}
    if result then
      if type(result.categories) == 'table' then
        for _, c in ipairs(result.categories) do
          labels[#labels + 1] = tostring(c)
        end
      end
      -- Coerce first: a probability delivered as a string would otherwise
      -- raise on the numeric comparison
      local prob = tonumber(result.probability)
      if prob then
        labels[#labels + 1] = prob > 0.5 and 'spam' or 'ham'
      end
    end
    -- Each label is pushed to the front, so the newest labels come first
    for _, l in ipairs(labels) do table.insert(ctx.last_spam_labels, 1, l) end
    while #ctx.last_spam_labels > opts.last_labels_count do table.remove(ctx.last_spam_labels) end

    ctx.updated_at = now()
    local payload = encode_json(ctx)

    -- SETEX rejects a non-positive expiry, and to_seconds yields 0 for values
    -- it cannot convert, so fall back to the built-in TTL in that case
    local ttl = math.floor(to_seconds(opts.ttl))
    if ttl < 1 then
      ttl = DEFAULTS.ttl
    end
    local expire_at = now() + ttl

    -- Log what we're storing in context
    lua_util.debugm(N, task,
        'storing context for %s: %s messages, labels=%s, top_senders=%s, flagged=%s, payload_size=%s bytes, expiring at %s',
        tostring(ident.identity or '(none)'),
        tostring(#ctx.recent_messages),
        table.concat(ctx.last_spam_labels or {}, ','),
        table.concat(ctx.top_senders or {}, ','),
        table.concat(ctx.flagged_phrases or {}, ','),
        tostring(#payload),
        os.date('%Y-%m-%d %H:%M:%S', expire_at))
    if msg then
      lua_util.debugm(N, task,
          'added message: from=%s, subject=%s, keywords=%s',
          tostring(msg.from or '(none)'),
          tostring(msg.subject or '(none)'),
          table.concat(msg.keywords or {}, ','))
    end

    local function on_set(set_err)
      if set_err then
        rspamd_logger.errx(task, 'llm_context: set failed: %s', tostring(set_err))
      else
        lua_util.debugm(N, task, 'context saved to redis: key=%s, ttl=%s, expiring at %s',
            tostring(ident.key), tostring(ttl), os.date('%Y-%m-%d %H:%M:%S', expire_at))
      end
    end

    local ok = lua_redis.redis_make_request(task, redis_params, ident.key, true, on_set, 'SETEX',
        { ident.key, tostring(ttl), payload })
    if not ok then
      rspamd_logger.errx(task, 'llm_context: set request was not scheduled')
    end
  end

  local ok = lua_redis.redis_make_request(task, redis_params, ident.key, false, on_get, 'GET', { ident.key })
  if not ok then
    rspamd_logger.errx(task, 'llm_context: initial get request was not scheduled')
  end
end
return M
|