File: test_aggregate_udf.cpp

package info (click to toggle)
duckdb 1.5.1-3
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 299,196 kB
  • sloc: cpp: 865,414; ansic: 57,292; python: 18,871; sql: 12,663; lisp: 11,751; yacc: 7,412; lex: 1,682; sh: 747; makefile: 564
file content (131 lines) | stat: -rw-r--r-- 6,384 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/common/types/date.hpp"
#include "duckdb/common/types/time.hpp"
#include "duckdb/common/types/timestamp.hpp"
#include "udf_functions_to_test.hpp"

using namespace duckdb;
using namespace std;

// Exercises aggregate UDF registration via Connection::CreateAggregateFunction:
//   * registration from template parameters only,
//   * registration with explicit LogicalTypes,
//   * rejection of template-type / LogicalType mismatches,
//   * connection-local (temporary) visibility of the registered functions.
// Template arguments are <operation, state, return type, input type(s)...>; when LogicalTypes are passed they
// are (return type, input type(s)...) and must agree with the template types.
// Catch re-enters this TEST_CASE from the top for every SECTION, so each section starts with a fresh
// in-memory database and connection (tables/functions created in one section are not visible in another).
TEST_CASE("Aggregate UDFs", "[udf_function]") {
	duckdb::unique_ptr<QueryResult> result;
	DuckDB db(nullptr);
	Connection con(db);
	con.EnableQueryVerification();

	SECTION("Testing a unary aggregate UDF using only template parameters") {
		// using DOUBLEs
		REQUIRE_NOTHROW(
		    con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<double>, double, double>("udf_avg_double"));

		// setup must succeed, otherwise the CHECK_COLUMN below would fail for an unrelated reason
		REQUIRE_NO_FAIL(con.Query("CREATE TABLE doubles (d DOUBLE)"));
		REQUIRE_NO_FAIL(con.Query("INSERT INTO doubles VALUES (1), (2), (3), (4), (5)"));
		result = con.Query("SELECT udf_avg_double(d) FROM doubles");
		REQUIRE(CHECK_COLUMN(result, 0, {3.0}));

		// using INTEGERs
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>("udf_avg_int"));

		REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers (i INTEGER)"));
		REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (1), (2), (3), (4), (5)"));
		result = con.Query("SELECT udf_avg_int(i) FROM integers");
		REQUIRE(CHECK_COLUMN(result, 0, {3}));
	}

	SECTION("Testing a binary aggregate UDF using only template parameters") {
		// using DOUBLEs
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
		    "udf_covar_pop_double"));

		// a single non-NULL pair yields 0; any NULL input yields NULL
		result = con.Query("SELECT udf_covar_pop_double(3,3), udf_covar_pop_double(NULL,3), "
		                   "udf_covar_pop_double(3,NULL), udf_covar_pop_double(NULL,NULL)");
		REQUIRE(CHECK_COLUMN(result, 0, {0}));
		REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
		REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
		REQUIRE(CHECK_COLUMN(result, 3, {Value()}));

		// using INTEGERs
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, int, int, int>(
		    "udf_covar_pop_int"));

		result = con.Query("SELECT udf_covar_pop_int(3,3), udf_covar_pop_int(NULL,3), udf_covar_pop_int(3,NULL), "
		                   "udf_covar_pop_int(NULL,NULL)");
		REQUIRE(CHECK_COLUMN(result, 0, {0}));
		REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
		REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
		REQUIRE(CHECK_COLUMN(result, 3, {Value()}));
	}

	SECTION("Testing aggregate UDF with arguments") {
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
		    "udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));

		REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers (i INTEGER)"));
		REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (1), (2), (3), (4), (5)"));
		result = con.Query("SELECT udf_avg_int_args(i) FROM integers");
		REQUIRE(CHECK_COLUMN(result, 0, {3}));

		// using TIMEs to test disambiguation
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<dtime_t>, dtime_t, dtime_t>(
		    "udf_avg_time_args", LogicalType::TIME, LogicalType::TIME));
		REQUIRE_NO_FAIL(con.Query("CREATE TABLE times (t TIME)"));
		REQUIRE_NO_FAIL(
		    con.Query("INSERT INTO times VALUES ('01:00:00'), ('01:00:00'), ('01:00:00'), ('01:00:00'), ('01:00:00')"));
		result = con.Query("SELECT udf_avg_time_args(t) FROM times");

		REQUIRE(CHECK_COLUMN(result, 0, {Time::FromString("01:00:00")}));

		// using DOUBLEs and a binary UDF
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
		    "udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE));

		result = con.Query("SELECT udf_covar_pop_double_args(3,3), udf_covar_pop_double_args(NULL,3), "
		                   "udf_covar_pop_double_args(3,NULL), udf_covar_pop_double_args(NULL,NULL)");
		REQUIRE(CHECK_COLUMN(result, 0, {0}));
		REQUIRE(CHECK_COLUMN(result, 1, {Value()}));
		REQUIRE(CHECK_COLUMN(result, 2, {Value()}));
		REQUIRE(CHECK_COLUMN(result, 3, {Value()}));
	}

	SECTION("Testing aggregate UDF with WRONG arguments") {
		// Each case registers with template types that disagree with the given LogicalTypes; registration must throw.

		// wrong return type
		REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, double, int>(
		    "udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
		REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
		    "udf_avg_int_args", LogicalType::DOUBLE, LogicalType::INTEGER));

		// wrong first argument
		REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, double>(
		    "udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
		REQUIRE_THROWS(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
		    "udf_avg_int_args", LogicalType::INTEGER, LogicalType::DOUBLE));

		// wrong second argument
		REQUIRE_THROWS(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, int>(
		    "udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE));
		REQUIRE_THROWS(con.CreateAggregateFunction<UDFCovarPopOperation, udf_covar_state_t, double, double, double>(
		    "udf_covar_pop_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::INTEGER));
	}

	SECTION("Checking if aggregate UDFs are temporary") {
		REQUIRE_NOTHROW(
		    con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<double>, double, double>("udf_avg_double"));
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>("udf_avg_int"));
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<int>, int, int>(
		    "udf_avg_int_args", LogicalType::INTEGER, LogicalType::INTEGER));
		REQUIRE_NOTHROW(con.CreateAggregateFunction<UDFAverageFunction, udf_avg_state_t<double>, double, double>(
		    "udf_avg_double_args", LogicalType::DOUBLE, LogicalType::DOUBLE));

		// the registering connection can use every function
		REQUIRE_NO_FAIL(con.Query("SELECT udf_avg_double(1)"));
		REQUIRE_NO_FAIL(con.Query("SELECT udf_avg_int(1)"));
		REQUIRE_NO_FAIL(con.Query("SELECT udf_avg_int_args(1)"));
		REQUIRE_NO_FAIL(con.Query("SELECT udf_avg_double_args(1)"));

		// Trying to use the aggregate UDFs with a different connection, it must fail
		Connection con_NEW(db);
		con_NEW.EnableQueryVerification();
		REQUIRE_FAIL(con_NEW.Query("SELECT udf_avg_double(1)"));
		REQUIRE_FAIL(con_NEW.Query("SELECT udf_avg_int(1)"));
		REQUIRE_FAIL(con_NEW.Query("SELECT udf_avg_int_args(1)"));
		REQUIRE_FAIL(con_NEW.Query("SELECT udf_avg_double_args(1)"));
	}
}