1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
module Neighbor
  # Class-level macro support for vector ("neighbor") attributes.
  #
  # NOTE(review): the body relies on +attribute+, +decorate_attributes+,
  # +normalizes+, +validate+, +scope+, +columns_hash+, +connection_pool+ and
  # similar methods being available on the extending class - presumably an
  # Active Record model; confirm against where this module is mixed in.
  module Model
    # Declares one or more attributes as neighbor attributes.
    #
    # For every attribute this records its options, wraps its attribute type
    # in Neighbor::Attribute and, when +normalize+ is truthy, registers a
    # normalization through Neighbor::Utils.normalize. The first declaration
    # in a class hierarchy additionally defines:
    #
    # * a validation that checks dimensions and finiteness of every non-nil
    #   neighbor attribute value
    # * the +nearest_neighbors+ scope (attribute name, query vector,
    #   +distance:+ and optional +precision:+)
    # * the +nearest_neighbors+ instance method, which runs that scope with
    #   the record's own value and excludes the record itself
    #
    # @param attribute_names [Array<Symbol, String>] attributes to register;
    #   converted to symbols
    # @param dimensions [Integer, nil] expected dimensions; when nil, the
    #   column limit is used for non-binary columns
    # @param normalize [Boolean, nil] whether values and query vectors are
    #   normalized; the scope requires it to be non-nil for cosine distance
    #   when Utils.normalize_required? says so
    # @param type [Symbol, String, nil] explicit vector type; the scope
    #   rejects it unless the adapter is :sqlite
    # @raise [ArgumentError] if no attribute name is given
    # @raise [Neighbor::Error] if an attribute was already registered on this
    #   class or on a superclass
    def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
      raise ArgumentError, "has_neighbors requires an attribute name" if attribute_names.empty?

      attribute_names.map!(&:to_sym)

      class_eval do
        @neighbor_attributes ||= {}

        # first declaration on this class: expose the registry, merged with
        # whatever the superclass chain has registered
        if @neighbor_attributes.empty?
          def self.neighbor_attributes
            parent_attributes =
              if superclass.respond_to?(:neighbor_attributes)
                superclass.neighbor_attributes
              else
                {}
              end

            parent_attributes.merge(@neighbor_attributes || {})
          end
        end

        attribute_names.each do |attribute_name|
          raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
          @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
        end

        # NOTE: the blocks below may run long after has_neighbors returns and
        # read `type` through the closure, so nothing else may reassign it
        if ActiveRecord::VERSION::STRING.to_f >= 7.2
          decorate_attributes(attribute_names) do |name, cast_type|
            Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
          end
        else
          attribute_names.each do |attribute_name|
            attribute attribute_name do |cast_type|
              Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
            end
          end
        end

        if normalize
          attribute_names.each do |attribute_name|
            normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
          end
        end

        # The validation, scope and instance method below are only needed
        # once. Compare against the merged registry (own + inherited) so that
        # a subclass adding attributes does not register them a second time
        # on top of what its superclass already defined.
        # NOTE(review): assumes validations, scopes and instance methods are
        # inherited by subclasses, as in Active Record.
        return if neighbor_attributes.size != attribute_names.size

        validate do
          adapter = Utils.adapter(self.class)

          self.class.neighbor_attributes.each do |k, v|
            value = read_attribute(k)
            next if value.nil?

            column_info = self.class.columns_hash[k.to_s]

            # these locals must not be named dimensions/type: that would
            # overwrite the has_neighbors arguments captured by this closure
            expected_dimensions = v[:dimensions]
            expected_dimensions ||= column_info&.limit unless column_info&.type == :binary
            vector_type = v[:type] || Utils.type(adapter, column_info&.type)

            # validate_dimensions returns nil when the value is acceptable
            unless Neighbor::Utils.validate_dimensions(value, vector_type, expected_dimensions, adapter).nil?
              errors.add(k, "must have #{expected_dimensions} dimensions")
            end

            unless Neighbor::Utils.validate_finite(value, vector_type)
              errors.add(k, "must have finite values")
            end
          end
        end

        scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
          attribute_name = attribute_name.to_sym
          options = neighbor_attributes[attribute_name]
          raise ArgumentError, "Invalid attribute" unless options

          # distinct names keep these per-query values out of the
          # has_neighbors closure (normalize/dimensions/type are shared there)
          normalize_option = options[:normalize]
          vector_dimensions = options[:dimensions]
          vector_type = options[:type]

          return none if vector.nil?

          distance = distance.to_s

          column_info = columns_hash[attribute_name.to_s]
          column_type = column_info&.type

          adapter = Neighbor::Utils.adapter(klass)
          if vector_type && adapter != :sqlite
            raise ArgumentError, "type only works with SQLite"
          end

          operator = Neighbor::Utils.operator(adapter, column_type, distance)
          raise ArgumentError, "Invalid distance: #{distance}" unless operator

          # ensure normalize set (can be true or false)
          normalize_required = Utils.normalize_required?(adapter, column_type)
          if distance == "cosine" && normalize_required && normalize_option.nil?
            raise Neighbor::Error, "Set normalize for cosine distance with cube"
          end

          column_attribute = klass.type_for_attribute(attribute_name)
          vector = column_attribute.cast(vector)
          vector_dimensions ||= column_info&.limit unless column_info&.type == :binary
          Neighbor::Utils.validate(vector, dimensions: vector_dimensions, type: vector_type || Utils.type(adapter, column_info&.type), adapter: adapter)
          vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize_option

          quoted_attribute = nil
          query = nil
          connection_pool.with_connection do |c|
            quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
            query = c.quote(column_attribute.serialize(vector))
          end

          unless precision.nil?
            if adapter != :postgresql || column_type != :vector
              raise ArgumentError, "Precision not supported for this type"
            end

            case precision.to_s
            when "half"
              cast_dimensions = vector_dimensions || column_info&.limit
              raise ArgumentError, "Unknown dimensions" unless cast_dimensions
              quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
            else
              raise ArgumentError, "Invalid precision"
            end
          end

          order = Utils.order(adapter, vector_type, operator, quoted_attribute, query)

          # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
          # with normalized vectors:
          # cosine similarity = 1 - (euclidean distance)**2 / 2
          # cosine distance = 1 - cosine similarity
          # this transformation doesn't change the order, so only needed for select
          neighbor_distance =
            if distance == "cosine" && normalize_required
              "POWER(#{order}, 2) / 2.0"
            elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
              "(#{order}) * -1"
            else
              order
            end

          # for select, use column_names instead of * to account for ignored columns
          select_columns = select_values.any? ? [] : column_names
          select(*select_columns, "#{neighbor_distance} AS neighbor_distance")
            .where.not(attribute_name => nil)
            .reorder(Arel.sql(order))
        }

        # Finds the records nearest to this record's own value of
        # +attribute_name+, excluding the record itself by primary key.
        # Remaining options are forwarded to the nearest_neighbors scope.
        def nearest_neighbors(attribute_name, **options)
          attribute_name = attribute_name.to_sym
          # important! check if neighbor attribute before accessing
          raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]

          self.class
            .where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
            .nearest_neighbors(attribute_name, self[attribute_name], **options)
        end
      end
    end
  end
end
|