File: model.rb

package info (click to toggle)
ruby-neighbor 0.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 204 kB
  • sloc: ruby: 840; makefile: 4
file content (163 lines) | stat: -rw-r--r-- 6,689 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
module Neighbor
  # Class-level macro for declaring vector ("neighbor") attributes on a model.
  module Model
    # Declares one or more attributes as neighbor attributes.
    #
    # Every call records the options for the given attributes, wraps their
    # cast type in Neighbor::Attribute and, when +normalize+ is truthy,
    # registers a +normalizes+ callback for them. The first call on a class
    # additionally installs a validation, the +nearest_neighbors+ scope and
    # the +nearest_neighbors+ instance method; later calls on the same class
    # return before that one-time setup.
    #
    # @param attribute_names [Array<Symbol, String>] attributes to declare
    # @param dimensions [Integer, nil] expected dimensions; when nil the
    #   column limit is used (except for binary columns)
    # @param normalize [Boolean, nil] whether vectors are normalized
    # @param type [Symbol, String, nil] explicit vector type (the scope
    #   rejects it for adapters other than SQLite)
    # @raise [ArgumentError] if no attribute name is given
    # @raise [Neighbor::Error] if an attribute was already declared
    def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
      if attribute_names.empty?
        raise ArgumentError, "has_neighbors requires an attribute name"
      end
      attribute_names.map!(&:to_sym)

      class_eval do
        @neighbor_attributes ||= {}

        # define the reader once per class; it merges in the options
        # declared on superclasses that respond to it
        if @neighbor_attributes.empty?
          def self.neighbor_attributes
            parent_attributes =
              if superclass.respond_to?(:neighbor_attributes)
                superclass.neighbor_attributes
              else
                {}
              end

            parent_attributes.merge(@neighbor_attributes || {})
          end
        end

        attribute_names.each do |attribute_name|
          raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
          @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
        end

        # NOTE: the blocks below capture the `type` argument and may run
        # after this method has returned, so no other block in this method
        # may reassign `dimensions`, `normalize` or `type`
        if ActiveRecord::VERSION::STRING.to_f >= 7.2
          decorate_attributes(attribute_names) do |name, cast_type|
            Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
          end
        else
          attribute_names.each do |attribute_name|
            attribute attribute_name do |cast_type|
              Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
            end
          end
        end

        if normalize
          attribute_names.each do |attribute_name|
            normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
          end
        end

        # everything below is one-time setup: skip it unless this is the
        # first has_neighbors call on this class (exits has_neighbors itself)
        return if @neighbor_attributes.size != attribute_names.size

        validate do
          adapter = Utils.adapter(self.class)

          self.class.neighbor_attributes.each do |k, v|
            value = read_attribute(k)
            next if value.nil?

            column_info = self.class.columns_hash[k.to_s]
            # distinct local names: assigning to `dimensions` / `type` here
            # would overwrite the has_neighbors arguments captured above
            expected_dimensions = v[:dimensions]
            expected_dimensions ||= column_info&.limit unless column_info&.type == :binary
            value_type = v[:type] || Utils.type(adapter, column_info&.type)

            unless Neighbor::Utils.validate_dimensions(value, value_type, expected_dimensions, adapter).nil?
              errors.add(k, "must have #{expected_dimensions} dimensions")
            end
            unless Neighbor::Utils.validate_finite(value, value_type)
              errors.add(k, "must have finite values")
            end
          end
        end

        # Orders records by distance to +vector+ and selects that distance
        # as +neighbor_distance+; rows where the attribute is NULL are
        # excluded. Returns +none+ when +vector+ is nil.
        scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
          attribute_name = attribute_name.to_sym
          options = neighbor_attributes[attribute_name]
          raise ArgumentError, "Invalid attribute" unless options
          # distinct local names keep this lambda from writing to the
          # has_neighbors arguments, which every call of the scope shares
          normalize_vector = options[:normalize]
          expected_dimensions = options[:dimensions]
          vector_type = options[:type]

          return none if vector.nil?

          distance = distance.to_s

          column_info = columns_hash[attribute_name.to_s]
          column_type = column_info&.type

          adapter = Neighbor::Utils.adapter(klass)
          if vector_type && adapter != :sqlite
            raise ArgumentError, "type only works with SQLite"
          end

          operator = Neighbor::Utils.operator(adapter, column_type, distance)
          raise ArgumentError, "Invalid distance: #{distance}" unless operator

          # ensure normalize set (can be true or false)
          normalize_required = Utils.normalize_required?(adapter, column_type)
          if distance == "cosine" && normalize_required && normalize_vector.nil?
            raise Neighbor::Error, "Set normalize for cosine distance with cube"
          end

          column_attribute = klass.type_for_attribute(attribute_name)
          vector = column_attribute.cast(vector)
          expected_dimensions ||= column_info&.limit unless column_info&.type == :binary
          Neighbor::Utils.validate(vector, dimensions: expected_dimensions, type: vector_type || Utils.type(adapter, column_info&.type), adapter: adapter)
          vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize_vector

          quoted_attribute = nil
          query = nil
          connection_pool.with_connection do |c|
            quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
            query = c.quote(column_attribute.serialize(vector))
          end

          if !precision.nil?
            if adapter != :postgresql || column_type != :vector
              raise ArgumentError, "Precision not supported for this type"
            end

            case precision.to_s
            when "half"
              cast_dimensions = expected_dimensions || column_info&.limit
              raise ArgumentError, "Unknown dimensions" unless cast_dimensions
              quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
            else
              raise ArgumentError, "Invalid precision"
            end
          end

          order = Utils.order(adapter, vector_type, operator, quoted_attribute, query)

          # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
          # with normalized vectors:
          # cosine similarity = 1 - (euclidean distance)**2 / 2
          # cosine distance = 1 - cosine similarity
          # this transformation doesn't change the order, so only needed for select
          neighbor_distance =
            if distance == "cosine" && normalize_required
              "POWER(#{order}, 2) / 2.0"
            elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
              "(#{order}) * -1"
            else
              order
            end

          # for select, use column_names instead of * to account for ignored columns
          select_columns = select_values.any? ? [] : column_names
          select(*select_columns, "#{neighbor_distance} AS neighbor_distance")
            .where.not(attribute_name => nil)
            .reorder(Arel.sql(order))
        }

        # Returns the other records of this class ordered by distance to
        # this record's value for +attribute_name+; +options+ are forwarded
        # to the +nearest_neighbors+ scope.
        def nearest_neighbors(attribute_name, **options)
          attribute_name = attribute_name.to_sym
          # important! check if neighbor attribute before accessing
          raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]

          self.class
            .where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
            .nearest_neighbors(attribute_name, self[attribute_name], **options)
        end
      end
    end
  end
end