File: verb-uncount.R

package info (click to toggle)
r-cran-dbplyr 2.3.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 2,376 kB
  • sloc: sh: 13; makefile: 2
file content (93 lines) | stat: -rw-r--r-- 2,891 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#' "Uncount" a database table
#'
#' This is a method for the tidyr `uncount()` generic. It uses a temporary
#' table, so your database user needs permissions to create one.
#'
#' @inheritParams arrange.tbl_lazy
#' @param weights A vector of weights. Evaluated in the context of `data`;
#'   supports quasiquotation.
#' @param .id Supply a string to create a new variable which gives a unique
#'   identifier for each created row.
#' @param .remove If `TRUE`, and `weights` is the name of a column in `data`,
#'   then this column is removed.
#' @export
#' @examples
#' df <- memdb_frame(x = c("a", "b"), n = c(1, 2))
#' dbplyr_uncount(df, n)
#' dbplyr_uncount(df, n, .id = "id")
#'
#' # You can also use constants
#' dbplyr_uncount(df, 2)
#'
#' # Or expressions
#' dbplyr_uncount(df, 2 / n)
dbplyr_uncount <- function(data, weights, .remove = TRUE, .id = NULL) {
  # Overview of the strategy:
  # 1. calculate `n_max`, the max weight
  # 2. create a temporary db table with ids from 1 to `n_max`
  # 3. duplicate data by inner joining with this temporary table under the
  #    condition that data$weight <= tmp$id
  #
  # An alternative approach would be to use a recursive CTE but this requires
  # more code and testing across backends.
  # See https://stackoverflow.com/questions/33327837/repeat-rows-n-times-according-to-column-value

  weights_quo <- enquo(weights)
  # `weights` is treated as an existing column only when it is a bare symbol
  # naming a column of `data`; anything else (a constant or expression) must
  # first be materialised into a synthetic helper column.
  weights_is_col <- quo_is_symbol(weights_quo) &&
    quo_name(weights_quo) %in% colnames(data)

  if (weights_is_col) {
    weights_col <- quo_name(weights_quo)
  } else {
    weights_col <- "..dbplyr_weight_col"
    data <- mutate(data, !!weights_col := !!weights_quo)
  }

  # The largest weight decides how many rows the helper id table needs.
  n_max <- pull(summarise(ungroup(data), max(!!sym(weights_col), na.rm = TRUE)))
  n_max <- vctrs::vec_cast(n_max, integer(), x_arg = "weights")

  # Without a user-supplied `.id`, use an internal name that is dropped again
  # at the end.
  row_id_col <- .id %||% "..dbplyr_row_id"
  tbl_name <- unclass(query_name(data)) %||% "LHS"
  sql_on_expr <- expr(RHS[[!!row_id_col]] <= (!!sym(tbl_name))[[!!weights_col]])

  con <- remote_con(data)
  # FIXME: Use generate_series() if available on database
  helper_table <- copy_inline(con, tibble(!!row_id_col := seq2(1, n_max)))

  # Cross-join-style inner join: each input row matches every helper id that
  # is <= its weight, so it is duplicated `weight` times.
  data_uncounted <- inner_join(
    data,
    helper_table,
    by = character(),
    sql_on = translate_sql(!!sql_on_expr, con = con),
    x_as = tbl_name,
    y_as = "RHS"
  )

  cols_to_remove <- character()
  if (is_null(.id)) {
    cols_to_remove <- c(cols_to_remove, row_id_col)
  }

  grps <- group_vars(data_uncounted)
  # Drop the weight column when it was synthesised above, or when the caller
  # asked for an existing weight column to be removed.
  if (!weights_is_col || .remove) {
    cols_to_remove <- c(cols_to_remove, weights_col)
    # need to regroup to be able to remove weights_col
    grps <- setdiff(grps, weights_col)
  }

  data_uncounted %>%
    ungroup() %>%
    select(-all_of(cols_to_remove)) %>%
    group_by(!!!syms(grps))
}

# Register names used via non-standard evaluation (the join aliases "RHS" and
# "LHS" inside quoted SQL expressions, plus tidyselect's `all_of()`) so that
# R CMD check does not flag them as undefined globals.
globalVariables(c(
  "RHS",
  "LHS",
  "all_of"
))