1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
|
local api = vim.api
local M = {}
---@class Range2
---@inlinedoc
---@field [1] integer start row
---@field [2] integer end row
---@class Range4
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer end row
---@field [4] integer end column
---@class Range6
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer start bytes
---@field [4] integer end row
---@field [5] integer end column
---@field [6] integer end bytes
---@alias Range Range2|Range4|Range6
---@private
--- Three-way comparison of two (row, col) positions. Rows are compared
--- first; columns only break ties between equal rows.
---@param a_row integer
---@param a_col integer
---@param b_row integer
---@param b_col integer
---@return integer
--- 1: a > b
--- 0: a == b
--- -1: a < b
local function cmp_pos(a_row, a_col, b_row, b_col)
  if a_row ~= b_row then
    -- Different rows: the row alone decides the order.
    if a_row > b_row then
      return 1
    end
    return -1
  end

  -- Same row: fall back to the column.
  if a_col > b_col then
    return 1
  end
  if a_col < b_col then
    return -1
  end
  return 0
end
--- Boolean predicates over (row, col) positions. Each one takes
--- (a_row, a_col, b_row, b_col) and tests the three-way result of `cmp_pos`,
--- e.g. `M.cmp_pos.lt(0, 5, 1, 0) == true`.
---
--- The table itself is also callable: `M.cmp_pos(a_row, a_col, b_row, b_col)`
--- returns the raw three-way result (1, 0 or -1).
M.cmp_pos = {
  lt = function(...)
    return cmp_pos(...) == -1
  end,
  le = function(...)
    return cmp_pos(...) ~= 1
  end,
  gt = function(...)
    return cmp_pos(...) == 1
  end,
  ge = function(...)
    return cmp_pos(...) ~= -1
  end,
  eq = function(...)
    return cmp_pos(...) == 0
  end,
  ne = function(...)
    return cmp_pos(...) ~= 0
  end,
}

-- Lua invokes a `__call` metamethod with the called object as its first
-- argument. Drop it before forwarding. Installing `cmp_pos` directly would
-- shift every argument by one (a_row would receive this table).
setmetatable(M.cmp_pos, {
  __call = function(_, ...)
    return cmp_pos(...)
  end,
})
---@private
---Check if a variable is a valid range object: a table whose length is
---exactly 4 (Range4) or 6 (Range6) and whose every slot holds a number.
---Two-element (Range2) tables are not accepted by this check.
---@param r any
---@return boolean
function M.validate(r)
  if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then
    return false
  end
  -- Index every slot up to #r instead of using ipairs. ipairs stops at the
  -- first nil, so a table with a hole (e.g. {1, 2, nil, 4}, whose length may
  -- still be reported as 4) would otherwise pass validation.
  for i = 1, #r do
    if type(r[i]) ~= 'number' then
      return false
    end
  end
  return true
end
---@private
--- Report whether two ranges overlap. End positions are compared exclusively,
--- so ranges that only touch (one ends exactly where the other starts) are
--- not considered overlapping.
---@param r1 Range
---@param r2 Range
---@return boolean
function M.intercepts(r1, r2)
  local a_srow, a_scol, a_erow, a_ecol = M.unpack4(r1)
  local b_srow, b_scol, b_erow, b_ecol = M.unpack4(r2)

  -- Overlap needs both conditions:
  --   r1 ends strictly after r2 starts   (r1 is not entirely above r2)
  --   r1 starts strictly before r2 ends  (r1 is not entirely below r2)
  return M.cmp_pos.gt(a_erow, a_ecol, b_srow, b_scol)
    and M.cmp_pos.lt(a_srow, a_scol, b_erow, b_ecol)
end
---@private
--- Normalize any Range flavor to (start_row, start_col, end_row, end_col).
--- A two-element range yields column 0 for both ends. A six-element range
--- has its byte offsets (slots 3 and 6) skipped.
---@param r Range
---@return integer, integer, integer, integer
function M.unpack4(r)
  local len = #r
  if len == 2 then
    return r[1], 0, r[2], 0
  elseif len == 6 then
    return r[1], r[2], r[4], r[5]
  end
  return r[1], r[2], r[3], r[4]
end
---@private
--- Spread a Range6 into its six components, in storage order:
--- start row, start column, start byte, end row, end column, end byte.
---@param r Range6
---@return integer, integer, integer, integer, integer, integer
function M.unpack6(r)
  local srow, scol, sbyte = r[1], r[2], r[3]
  local erow, ecol, ebyte = r[4], r[5], r[6]
  return srow, scol, sbyte, erow, ecol, ebyte
end
---@private
--- Report whether r1 fully encloses r2. Coinciding boundaries still count
--- as contained.
---@param r1 Range
---@param r2 Range
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
  local outer_srow, outer_scol, outer_erow, outer_ecol = M.unpack4(r1)
  local inner_srow, inner_scol, inner_erow, inner_ecol = M.unpack4(r2)

  -- r1 must start no later than r2 starts, and end no earlier than r2 ends.
  return M.cmp_pos.le(outer_srow, outer_scol, inner_srow, inner_scol)
    and M.cmp_pos.ge(outer_erow, outer_ecol, inner_erow, inner_ecol)
end
--- @private
--- Byte offset at which 0-based line `index` begins in `source`.
--- For a buffer handle (number) this defers to nvim_buf_get_offset.
--- For a string it is the 1-based position of the index-th '\n', which equals
--- the 0-based offset of the byte right after that newline.
--- NOTE: for a string with fewer than `index` newlines the gmatch iterator is
--- exhausted and nil is returned.
--- @param source integer|string
--- @param index integer
--- @return integer
local function get_offset(source, index)
  if index == 0 then
    return 0
  elseif type(source) == 'number' then
    return api.nvim_buf_get_offset(source, index)
  end

  -- Each call yields the position of the next newline; keep the last one seen.
  local newline_positions = source:gmatch('()\n')
  local offset = 0
  for _ = 1, index do
    offset = newline_positions() --[[@as integer]]
  end
  return offset
end
---@private
--- Promote a Range to a Range6 by computing start and end byte offsets.
--- A table that already has six elements is returned as-is (same table).
---@param source integer|string
---@param range Range
---@return Range6
function M.add_bytes(source, range)
  local has_bytes = type(range) == 'table' and #range == 6
  if has_bytes then
    return range --[[@as Range6]]
  end

  local srow, scol, erow, ecol = M.unpack4(range)
  -- TODO(vigoux): proper byte computation here, and account for EOL ?
  -- Byte offset = offset of the row's first byte + the column, which is
  -- presumably a byte column (not a character column); confirm with callers.
  local sbyte = get_offset(source, srow) + scol
  local ebyte = get_offset(source, erow) + ecol
  return { srow, scol, sbyte, erow, ecol, ebyte }
end
return M
|