1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#' Register user-defined functions
#'
#' These functions support calling R code from query engine execution
#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]).
#' Use [register_scalar_function()] attach Arrow input and output types to an
#' R function and make it available for use in the dplyr interface and/or
#' [call_function()]. Scalar functions are currently the only type of
#' user-defined function supported. In Arrow, scalar functions must be
#' stateless and return output with the same shape (i.e., the same number
#' of rows) as the input.
#'
#' @param name The function name to be used in the dplyr bindings
#' @param in_type A [DataType] of the input type or a [schema()]
#' for functions with more than one argument. This signature will be used
#' to determine if this function is appropriate for a given set of arguments.
#' If this function is appropriate for more than one signature, pass a
#' `list()` of the above.
#' @param out_type A [DataType] of the output type or a function accepting
#' a single argument (`types`), which is a `list()` of [DataType]s. If a
#' function it must return a [DataType].
#' @param fun An R function or rlang-style lambda expression. The function
#' will be called with a first argument `context` which is a `list()`
#' with elements `batch_size` (the expected length of the output) and
#' `output_type` (the required [DataType] of the output) that may be used
#' to ensure that the output has the correct type and length. Subsequent
#' arguments are passed by position as specified by `in_types`. If
#' `auto_convert` is `TRUE`, subsequent arguments are converted to
#' R vectors before being passed to `fun` and the output is automatically
#' constructed with the expected output type via [as_arrow_array()].
#' @param auto_convert Use `TRUE` to convert inputs before passing to `fun`
#' and construct an Array of the correct type from the output. Use this
#' option to write functions of R objects as opposed to functions of
#' Arrow R6 objects.
#'
#' @return `NULL`, invisibly
#' @export
#'
#' @examplesIf arrow_with_dataset() && identical(Sys.getenv("NOT_CRAN"), "true")
#' library(dplyr, warn.conflicts = FALSE)
#'
#' some_model <- lm(mpg ~ disp + cyl, data = mtcars)
#' register_scalar_function(
#' "mtcars_predict_mpg",
#' function(context, disp, cyl) {
#' predict(some_model, newdata = data.frame(disp, cyl))
#' },
#' in_type = schema(disp = float64(), cyl = float64()),
#' out_type = float64(),
#' auto_convert = TRUE
#' )
#'
#' as_arrow_table(mtcars) |>
#' transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) |>
#' collect() |>
#' head()
#'
register_scalar_function <- function(name, fun, in_type, out_type, auto_convert = FALSE) {
assert_that(is.string(name))
scalar_function <- arrow_scalar_function(
fun,
in_type,
out_type,
auto_convert = auto_convert
)
# register with Arrow C++ function registry (enables its use in
# call_function() and Expression$create())
RegisterScalarUDF(name, scalar_function)
# register with dplyr binding (enables its use in mutate(), filter(), etc.)
binding_fun <- function(...) Expression$create(name, ...)
# inject the value of `name` into the expression to avoid saving this
# execution environment in the binding, which eliminates a warning when the
# same binding is registered twice
body(binding_fun) <- expr_substitute(body(binding_fun), sym("name"), name)
environment(binding_fun) <- asNamespace("arrow")
register_binding(name, binding_fun)
invisible(NULL)
}
arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) {
assert_that(is.function(fun))
# Create a small wrapper function that is easier to call from C++.
# TODO(ARROW-17148): This wrapper could be implemented in C/C++ to
# reduce evaluation overhead and generate prettier backtraces when
# errors occur (probably using a similar approach to purrr).
if (auto_convert) {
wrapper_fun <- function(context, args) {
args <- lapply(args, as.vector)
result <- do.call(fun, c(list(context), args))
as_arrow_array(result, type = context$output_type)
}
} else {
wrapper_fun <- function(context, args) {
do.call(fun, c(list(context), args))
}
}
# in_type can be a list() if registering multiple kernels at once
if (is.list(in_type)) {
in_type <- lapply(in_type, in_type_as_schema)
} else {
in_type <- list(in_type_as_schema(in_type))
}
# out_type can be a list() if registering multiple kernels at once
if (is.list(out_type)) {
out_type <- lapply(out_type, out_type_as_function)
} else {
out_type <- list(out_type_as_function(out_type))
}
# recycle out_type (which is frequently length 1 even if multiple kernels
# are being registered at once)
out_type <- rep_len(out_type, length(in_type))
# check n_kernels and number of args in fun
n_kernels <- length(in_type)
if (n_kernels == 0) {
abort("Can't register user-defined scalar function with 0 kernels")
}
expected_n_args <- in_type[[1]]$num_fields + 1L
fun_formals_have_dots <- any(names(formals(fun)) == "...")
if (!fun_formals_have_dots && length(formals(fun)) != expected_n_args) {
abort(
sprintf(
paste0(
"Expected `fun` to accept %d argument(s)\n",
"but found a function that accepts %d argument(s)\n",
"Did you forget to include `context` as the first argument?"
),
expected_n_args,
length(formals(fun))
)
)
}
structure(
list(
wrapper_fun = wrapper_fun,
in_type = in_type,
out_type = out_type
),
class = "arrow_scalar_function"
)
}
# This function sanitizes the in_type argument for arrow_scalar_function(),
# which can be a data type (e.g., int32()), a field for a unary function
# or a schema() for functions accepting more than one argument. C++ expects
# a schema().
in_type_as_schema <- function(x) {
if (inherits(x, "Field")) {
schema(x)
} else if (inherits(x, "DataType")) {
schema(field("", x))
} else {
as_schema(x)
}
}
# This function sanitizes the out_type argument for arrow_scalar_function(),
# which can be a data type (e.g., int32()) or a function of the input types.
# C++ currently expects a function.
out_type_as_function <- function(x) {
if (is.function(x)) {
x
} else {
x <- as_data_type(x)
function(types) x
}
}
|