1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
|
module Neighbor
  # Stateless helpers shared by the nearest-neighbor query and attribute code:
  # input validation, normalization, adapter detection and per-adapter SQL
  # operator/expression construction.
  module Utils
    # Checks that +value+ has the expected number of dimensions.
    #
    # @param value [Object] the vector; for :sparsevec it must respond to
    #   +dimensions+, otherwise to +size+
    # @param type [Symbol] column type (e.g. :vector, :sparsevec, :bit)
    # @param expected [Integer, nil] required dimension count; nil skips the check
    # @param adapter [Symbol] adapter symbol as returned by {adapter}
    # @return [String, nil] an error message on mismatch, otherwise nil
    def self.validate_dimensions(value, type, expected, adapter)
      dimensions =
        if type == :sparsevec
          value.dimensions
        elsif type == :bit && [:sqlite, :mysql].include?(adapter)
          # On these adapters bit vectors are counted as 8 dimensions per byte.
          # Count bytes, not characters, so binary data tagged with a multibyte
          # encoding (e.g. UTF-8) is not under-counted.
          (value.respond_to?(:bytesize) ? value.bytesize : value.size) * 8
        else
          value.size
        end

      "Expected #{expected} dimensions, not #{dimensions}" if expected && dimensions != expected
    end

    # Returns whether every element of +value+ is finite (no NaN/Infinity).
    #
    # @param value [Object] the vector; for :sparsevec its +values+ are checked
    # @param type [Symbol] column type
    # @return [Boolean] always true for :bit and :integer columns
    def self.validate_finite(value, type)
      case type
      when :bit, :integer
        # no floating-point elements to check for these types
        true
      when :sparsevec
        value.values.all?(&:finite?)
      else
        value.all?(&:finite?)
      end
    end

    # Validates dimensions and finiteness of +value+.
    #
    # @param value [Object] the vector to check
    # @param dimensions [Integer, nil] expected dimension count (nil skips that check)
    # @param type [Symbol] column type
    # @param adapter [Symbol] adapter symbol as returned by {adapter}
    # @return [nil]
    # @raise [Error] when the dimension count is wrong or a value is not finite
    def self.validate(value, dimensions:, type:, adapter:)
      message = validate_dimensions(value, type, dimensions, adapter)
      raise Error, message if message

      raise Error, "Values must be finite" unless validate_finite(value, type)
    end

    # Scales +value+ to unit L2 norm.
    #
    # @param value [Array<Numeric>, nil] the vector to normalize
    # @param column_info [#type, nil] column metadata; its +type+ must be
    #   :cube, :vector or :halfvec
    # @return [Array<Numeric>, nil] the normalized vector, +value+ itself for a
    #   zero vector, or nil when +value+ is nil
    # @raise [Error] when the column type does not support normalization
    def self.normalize(value, column_info:)
      return nil if value.nil?

      raise Error, "Normalize not supported for type" unless [:cube, :vector, :halfvec].include?(column_info&.type)

      norm = Math.sqrt(value.sum { |v| v * v })

      # store zero vector as all zeros
      # since NaN makes the distance always 0
      # could also throw error
      norm > 0 ? value.map { |v| v / norm } : value
    end

    # @return [Boolean] true when +value+ is non-nil and responds to +to_a+
    #   (nil is excluded explicitly because NilClass#to_a exists)
    def self.array?(value)
      !value.nil? && value.respond_to?(:to_a)
    end

    # Detects the database flavor behind +model+'s connection.
    #
    # @param model [Class] model class exposing +connection_db_config+ and
    #   +connection_pool+
    # @return [Symbol] :sqlite, :mariadb, :mysql, or :postgresql (the fallback)
    def self.adapter(model)
      case model.connection_db_config.adapter
      when /sqlite/i
        :sqlite
      when /mysql|trilogy/i
        # MariaDB shares the MySQL adapters; ask a live connection which it is.
        # NOTE(review): Object#try is ActiveSupport, presumably loaded with ActiveRecord.
        model.connection_pool.with_connection { |c| c.try(:mariadb?) } ? :mariadb : :mysql
      else
        :postgresql
      end
    end

    # Maps a column type to the vector type used internally.
    #
    # @return [Symbol] :bit for binary columns on MySQL, otherwise +column_type+
    def self.type(adapter, column_type)
      adapter == :mysql && column_type == :binary ? :bit : column_type
    end

    # Returns the SQL operator or function name for a distance metric.
    #
    # @param adapter [Symbol] adapter symbol as returned by {adapter}
    # @param column_type [Symbol] column type
    # @param distance [String] metric name (e.g. "euclidean", "cosine", "hamming")
    # @return [String, nil] the operator/function name, or nil when the
    #   distance is not supported for this adapter/type
    # @raise [ArgumentError] when the column type is unsupported
    #   (mariadb, mysql and postgresql only; sqlite ignores +column_type+)
    def self.operator(adapter, column_type, distance)
      case adapter
      when :sqlite
        case distance
        when "euclidean"
          "vec_distance_L2"
        when "cosine"
          "vec_distance_cosine"
        when "taxicab"
          "vec_distance_L1"
        when "hamming"
          "vec_distance_hamming"
        end
      when :mariadb
        case column_type
        when :vector
          case distance
          when "euclidean"
            "VEC_DISTANCE_EUCLIDEAN"
          when "cosine"
            "VEC_DISTANCE_COSINE"
          end
        when :integer
          case distance
          when "hamming"
            "BIT_COUNT"
          end
        else
          raise ArgumentError, "Unsupported type: #{column_type}"
        end
      when :mysql
        case column_type
        when :vector
          case distance
          when "cosine"
            "COSINE"
          when "euclidean"
            "EUCLIDEAN"
          end
        when :binary
          case distance
          when "hamming"
            "BIT_COUNT"
          end
        else
          raise ArgumentError, "Unsupported type: #{column_type}"
        end
      else
        case column_type
        when :bit
          case distance
          when "hamming"
            "<~>"
          when "jaccard"
            "<%>"
          when "hamming2"
            "#"
          end
        when :vector, :halfvec, :sparsevec
          case distance
          when "inner_product"
            "<#>"
          when "cosine"
            "<=>"
          when "euclidean"
            "<->"
          when "taxicab"
            "<+>"
          end
        when :cube
          case distance
          when "taxicab"
            "<#>"
          when "chebyshev"
            "<=>"
          when "euclidean", "cosine"
            "<->"
          end
        else
          raise ArgumentError, "Unsupported type: #{column_type}"
        end
      end
    end

    # Builds the SQL expression computing the distance between a column and a
    # query value.
    #
    # NOTE(review): +quoted_attribute+ and +query+ are interpolated verbatim;
    # callers must pass already-quoted/escaped SQL fragments.
    #
    # @param adapter [Symbol] adapter symbol as returned by {adapter}
    # @param type [Symbol] vector type (see {type})
    # @param operator [String] value returned by {operator}
    # @return [String] SQL expression
    def self.order(adapter, type, operator, quoted_attribute, query)
      case adapter
      when :sqlite
        case type
        when :int8
          "#{operator}(vec_int8(#{quoted_attribute}), vec_int8(#{query}))"
        when :bit
          "#{operator}(vec_bit(#{quoted_attribute}), vec_bit(#{query}))"
        else
          "#{operator}(#{quoted_attribute}, #{query})"
        end
      when :mariadb
        if operator == "BIT_COUNT"
          # popcount of the XOR = number of differing bits (Hamming distance)
          "BIT_COUNT(#{quoted_attribute} ^ #{query})"
        else
          "#{operator}(#{quoted_attribute}, #{query})"
        end
      when :mysql
        if operator == "BIT_COUNT"
          "BIT_COUNT(#{quoted_attribute} ^ #{query})"
        elsif operator == "COSINE"
          "DISTANCE(#{quoted_attribute}, #{query}, 'COSINE')"
        else
          "DISTANCE(#{quoted_attribute}, #{query}, 'EUCLIDEAN')"
        end
      else
        if operator == "#"
          "bit_count(#{quoted_attribute} # #{query})"
        else
          "#{quoted_attribute} #{operator} #{query}"
        end
      end
    end

    # @return [Boolean] true only for cube columns on PostgreSQL
    def self.normalize_required?(adapter, column_type)
      adapter == :postgresql && column_type == :cube
    end
  end
end
|