File: main.go

package info (click to toggle)
golang-mongodb-mongo-driver 1.17.1+ds1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 25,988 kB
  • sloc: perl: 533; ansic: 491; python: 432; sh: 327; makefile: 174
file content (114 lines) | stat: -rw-r--r-- 3,808 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package main

import (
	"context"
	"fmt"
	"os"
	"sort"
	"strings"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// datakeyopts maps each supported PROVIDER value to the master key document
// handed to CreateDataKey for that KMS provider. Its keys are also the list of
// providers advertised in the usage text.
//
// The "azure" entry only holds empty placeholders for keyName and
// keyVaultEndpoint; the real values come from the AZUREKMS_KEY_NAME and
// AZUREKMS_KEY_VAULT_ENDPOINT environment variables read in main.
var datakeyopts = map[string]primitive.M{
	"gcp": bson.M{
		"projectId": "devprod-drivers",
		"location":  "global",
		"keyRing":   "key-ring-csfle",
		"keyName":   "key-name-csfle",
	},
	"aws": bson.M{
		"key":    "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
		"region": "us-east-1",
	},
	"azure": bson.M{
		"keyName":          "",
		"keyVaultEndpoint": "",
	},
}

// main creates a data key with the KMS provider named by PROVIDER and checks
// the outcome against EXPECT_ERROR.
//
// Environment variables:
//   - MONGODB_URI: URI used to connect the key vault client (required).
//   - PROVIDER: one of the keys of datakeyopts (required).
//   - EXPECT_ERROR: substring the CreateDataKey error must contain; when
//     empty, CreateDataKey must succeed.
//   - AZUREKMS_KEY_NAME, AZUREKMS_KEY_VAULT_ENDPOINT: required when
//     PROVIDER=azure; they supply the Azure master key fields.
//
// Invalid configuration prints usage and exits with status 1. A
// CreateDataKey outcome that does not match EXPECT_ERROR panics.
func main() {
	uri := os.Getenv("MONGODB_URI")
	provider := os.Getenv("PROVIDER")
	// expecterror is an expected error substring. Set to empty string to expect no error.
	expecterror := os.Getenv("EXPECT_ERROR")

	datakeyopt, validKmsProvider := datakeyopts[provider]
	ok := false
	switch {
	case uri == "":
		fmt.Println("ERROR: Please set required MONGODB_URI environment variable.")
	case provider == "":
		fmt.Println("ERROR: Please set required PROVIDER environment variable.")
	case !validKmsProvider:
		fmt.Println("ERROR: Unsupported PROVIDER value.")
	default:
		ok = true
	}
	if provider == "azure" {
		azureKmsKeyName := os.Getenv("AZUREKMS_KEY_NAME")
		azureKmsKeyVaultEndpoint := os.Getenv("AZUREKMS_KEY_VAULT_ENDPOINT")
		if azureKmsKeyName == "" {
			fmt.Println("ERROR: Please set required AZUREKMS_KEY_NAME environment variable.")
			ok = false
		}
		if azureKmsKeyVaultEndpoint == "" {
			fmt.Println("ERROR: Please set required AZUREKMS_KEY_VAULT_ENDPOINT environment variable.")
			ok = false
		}
		// Build a fresh master key document instead of writing into the
		// shared package-level datakeyopts["azure"] map.
		datakeyopt = bson.M{
			"keyName":          azureKmsKeyName,
			"keyVaultEndpoint": azureKmsKeyVaultEndpoint,
		}
	}
	if !ok {
		printUsage()
		os.Exit(1)
	}

	cOpts := options.Client().ApplyURI(uri)
	keyVaultClient, err := mongo.Connect(context.Background(), cOpts)
	if err != nil {
		panic(fmt.Sprintf("Connect error: %v", err))
	}
	// Best-effort cleanup on the way out; a Disconnect failure is not
	// actionable in this short-lived program.
	defer func() { _ = keyVaultClient.Disconnect(context.Background()) }()

	// The provider is registered with an empty credentials map.
	// NOTE(review): this presumably relies on the driver obtaining KMS
	// credentials on demand from the environment — confirm against the
	// driver's KMS provider documentation.
	kmsProvidersMap := map[string]map[string]interface{}{
		provider: {},
	}
	ceOpts := options.ClientEncryption().SetKmsProviders(kmsProvidersMap).SetKeyVaultNamespace("keyvault.datakeys")
	ce, err := mongo.NewClientEncryption(keyVaultClient, ceOpts)
	if err != nil {
		panic(fmt.Sprintf("Error in NewClientEncryption: %v", err))
	}
	// Release the ClientEncryption's resources; deferred after the client's
	// Disconnect so it runs first (defers are LIFO). Best-effort, as above.
	defer func() { _ = ce.Close(context.Background()) }()

	dkOpts := options.DataKey().SetMasterKey(datakeyopt)
	_, err = ce.CreateDataKey(context.Background(), provider, dkOpts)
	switch {
	case expecterror == "":
		if err != nil {
			panic(fmt.Sprintf("Expected success, but got error in CreateDataKey: %v", err))
		}
	case err == nil:
		panic(fmt.Sprintf("Expected error message to contain %q, but got no error", expecterror))
	case !strings.Contains(err.Error(), expecterror):
		panic(fmt.Sprintf("Expected error message to contain %q, but got %q", expecterror, err.Error()))
	}
}

// printUsage prints the environment variables main understands, listing the
// supported PROVIDER values (the keys of datakeyopts) in sorted order.
func printUsage() {
	providers := make([]string, 0, len(datakeyopts))
	for p := range datakeyopts {
		providers = append(providers, p)
	}
	// Map iteration order is randomized; sort so the help text is stable.
	sort.Strings(providers)

	fmt.Println("The following environment variables are understood:")
	fmt.Println("- MONGODB_URI as a MongoDB URI. Example: 'mongodb://localhost:27017'")
	fmt.Println("- EXPECT_ERROR as an optional expected error substring.")
	fmt.Println("- PROVIDER as a KMS provider, which supports:", strings.Join(providers, ", "))
	fmt.Println("- AZUREKMS_KEY_NAME as the Azure key name. Required if PROVIDER=azure.")
	fmt.Println("- AZUREKMS_KEY_VAULT_ENDPOINT as the Azure key vault endpoint. Required if PROVIDER=azure.")
}