File: udf.R

package info (click to toggle)
apache-arrow 23.0.1-1
  • links: PTS
  • area: main
  • in suites: sid
  • size: 76,220 kB
  • sloc: cpp: 654,608; python: 70,522; ruby: 45,964; ansic: 18,742; sh: 7,365; makefile: 669; javascript: 125; xml: 41
file content (194 lines) | stat: -rw-r--r-- 7,274 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

#' Register user-defined functions
#'
#' These functions support calling R code from query engine execution
#' (e.g., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
#' Use [register_scalar_function()] to attach Arrow input and output types to an
#' R function and make it available for use in the dplyr interface and/or
#' [call_function()]. Scalar functions are currently the only type of
#' user-defined function supported. In Arrow, scalar functions must be
#' stateless and return output with the same shape (i.e., the same number
#' of rows) as the input.
#'
#' @param name The function name to be used in the dplyr bindings
#' @param in_type A [DataType] of the input type or a [schema()]
#'   for functions with more than one argument. This signature will be used
#'   to determine if this function is appropriate for a given set of arguments.
#'   If this function is appropriate for more than one signature, pass a
#'   `list()` of the above.
#' @param out_type A [DataType] of the output type or a function accepting
#'   a single argument (`types`), which is a `list()` of [DataType]s. If a
#'   function, it must return a [DataType].
#' @param fun An R function or rlang-style lambda expression. The function
#'   will be called with a first argument `context` which is a `list()`
#'   with elements `batch_size` (the expected length of the output) and
#'   `output_type` (the required [DataType] of the output) that may be used
#'   to ensure that the output has the correct type and length. Subsequent
#'   arguments are passed by position as specified by `in_type`. If
#'   `auto_convert` is `TRUE`, subsequent arguments are converted to
#'   R vectors before being passed to `fun` and the output is automatically
#'   constructed with the expected output type via [as_arrow_array()].
#' @param auto_convert Use `TRUE` to convert inputs before passing to `fun`
#'   and construct an Array of the correct type from the output. Use this
#'   option to write functions of R objects as opposed to functions of
#'   Arrow R6 objects.
#'
#' @return `NULL`, invisibly
#' @export
#'
#' @examplesIf arrow_with_dataset() && identical(Sys.getenv("NOT_CRAN"), "true")
#' library(dplyr, warn.conflicts = FALSE)
#'
#' some_model <- lm(mpg ~ disp + cyl, data = mtcars)
#' register_scalar_function(
#'   "mtcars_predict_mpg",
#'   function(context, disp, cyl) {
#'     predict(some_model, newdata = data.frame(disp, cyl))
#'   },
#'   in_type = schema(disp = float64(), cyl = float64()),
#'   out_type = float64(),
#'   auto_convert = TRUE
#' )
#'
#' as_arrow_table(mtcars) |>
#'   transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) |>
#'   collect() |>
#'   head()
#'
register_scalar_function <- function(name, fun, in_type, out_type, auto_convert = FALSE) {
  assert_that(is.string(name))

  # Validate and normalise the user's function and type signatures into an
  # "arrow_scalar_function" object.
  udf <- arrow_scalar_function(
    fun,
    in_type,
    out_type,
    auto_convert = auto_convert
  )

  # Step 1: Arrow C++ function registry. This is what makes `name` usable from
  # call_function() and Expression$create().
  RegisterScalarUDF(name, udf)

  # Step 2: dplyr binding (mutate(), filter(), ...). The literal value of
  # `name` is spliced into the binding's body and the closure is re-homed in
  # the arrow namespace, so this call's execution environment is not captured
  # by the binding. This eliminates a warning when the same binding is
  # registered twice.
  dplyr_binding <- function(...) Expression$create(name, ...)
  body(dplyr_binding) <- expr_substitute(body(dplyr_binding), sym("name"), name)
  environment(dplyr_binding) <- asNamespace("arrow")
  register_binding(name, dplyr_binding)

  invisible(NULL)
}

# Build the "arrow_scalar_function" object that register_scalar_function()
# passes to RegisterScalarUDF().
#
# fun: an R function, or a one-sided rlang-style lambda formula, whose first
#   argument receives the kernel `context` and whose remaining arguments
#   receive the kernel inputs by position.
# in_type: a DataType, a Field, or anything as_schema() accepts, or a list()
#   of those to register one kernel per element.
# out_type: a function of the input types, or anything as_data_type()
#   accepts, or a list() of those; recycled to the number of kernels.
# auto_convert: if TRUE, inputs are converted with as.vector() before `fun`
#   is called and the result is converted with as_arrow_array() using
#   `context$output_type`.
#
# Returns a list(wrapper_fun, in_type = <list of schemas>,
# out_type = <list of functions>) with class "arrow_scalar_function".
# Calls abort() when there are no kernels, no output types, or when a `fun`
# without `...` cannot accept the number of arguments some kernel passes.
arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) {
  # The exported documentation accepts rlang-style lambdas (formulas), so
  # convert them before asserting that `fun` is a function.
  if (inherits(fun, "formula")) {
    fun <- rlang::as_function(fun)
  }
  assert_that(is.function(fun))

  # Create a small wrapper function that is easier to call from C++.
  # TODO(ARROW-17148): This wrapper could be implemented in C/C++ to
  # reduce evaluation overhead and generate prettier backtraces when
  # errors occur (probably using a similar approach to purrr).
  if (auto_convert) {
    wrapper_fun <- function(context, args) {
      args <- lapply(args, as.vector)
      result <- do.call(fun, c(list(context), args))
      as_arrow_array(result, type = context$output_type)
    }
  } else {
    wrapper_fun <- function(context, args) {
      do.call(fun, c(list(context), args))
    }
  }

  # in_type can be a list() if registering multiple kernels at once
  if (is.list(in_type)) {
    in_type <- lapply(in_type, in_type_as_schema)
  } else {
    in_type <- list(in_type_as_schema(in_type))
  }

  n_kernels <- length(in_type)
  if (n_kernels == 0) {
    abort("Can't register user-defined scalar function with 0 kernels")
  }

  # out_type can be a list() if registering multiple kernels at once
  if (is.list(out_type)) {
    out_type <- lapply(out_type, out_type_as_function)
  } else {
    out_type <- list(out_type_as_function(out_type))
  }

  # Recycling an empty list would yield NULL entries rather than callable
  # output-type functions, so reject it up front.
  if (length(out_type) == 0) {
    abort("Can't register user-defined scalar function with 0 output types")
  }

  # recycle out_type (which is frequently length 1 even if multiple kernels
  # are being registered at once)
  out_type <- rep_len(out_type, n_kernels)

  # A `fun` without `...` must accept `context` plus one argument per input
  # field. The same `fun` is used for every kernel, so check all of them
  # (not only the first) and report the first mismatch.
  fun_formals <- formals(fun)
  fun_formals_have_dots <- "..." %in% names(fun_formals)
  if (!fun_formals_have_dots) {
    expected_n_args <- vapply(
      in_type,
      function(kernel_schema) as.integer(kernel_schema$num_fields) + 1L,
      integer(1)
    )
    mismatched <- which(expected_n_args != length(fun_formals))
    if (length(mismatched) > 0) {
      abort(
        sprintf(
          paste0(
            "Expected `fun` to accept %d argument(s)\n",
            "but found a function that accepts %d argument(s)\n",
            "Did you forget to include `context` as the first argument?"
          ),
          expected_n_args[[mismatched[[1]]]],
          length(fun_formals)
        )
      )
    }
  }

  structure(
    list(
      wrapper_fun = wrapper_fun,
      in_type = in_type,
      out_type = out_type
    ),
    class = "arrow_scalar_function"
  )
}

# Normalise one element of the `in_type` argument of arrow_scalar_function()
# into the schema() form that the C++ side expects:
# - a Field is wrapped in a one-field schema;
# - a bare DataType (e.g., int32()) becomes a one-field schema whose field
#   name is the empty string;
# - anything else (typically a schema() for multi-argument functions) is
#   handed to as_schema().
in_type_as_schema <- function(x) {
  if (inherits(x, "Field")) {
    return(schema(x))
  }
  if (inherits(x, "DataType")) {
    return(schema(field("", x)))
  }
  as_schema(x)
}

# Normalise one element of the `out_type` argument of arrow_scalar_function().
# C++ currently expects a function of the input types, so a function is
# returned unchanged. Anything else (e.g., int32()) is converted with
# as_data_type() and wrapped in a function that ignores `types` and always
# returns that data type.
out_type_as_function <- function(x) {
  if (is.function(x)) {
    return(x)
  }
  resolved_type <- as_data_type(x)
  function(types) resolved_type
}