File: windows.go

package info (click to toggle)
golang-github-azuread-microsoft-authentication-extensions-for-go 0.0~git20231002.7e3b8e2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 260 kB
  • sloc: makefile: 4
file content (115 lines) | stat: -rw-r--r-- 2,788 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

//go:build windows
// +build windows

package accessor

import (
	"context"
	"errors"
	"os"
	"path/filepath"
	"sync"
	"unsafe"

	"golang.org/x/sys/windows"
)

// Storage persists data in a single file whose contents are encrypted with the
// Windows data protection API (DPAPI). A Storage must be created with New; the
// zero value has a nil mutex and is not usable.
type Storage struct {
	// m serializes access to the file: Read takes the read lock, while Write
	// and Delete take the write lock.
	m *sync.RWMutex
	// p is the path of the backing file.
	p string
}

// New returns a Storage that keeps its data in the file at path p. It does not
// touch the filesystem and currently never returns an error.
func New(p string) (*Storage, error) {
	s := &Storage{p: p, m: new(sync.RWMutex)}
	return s, nil
}

// Delete removes the backing file while holding the write lock. A file that is
// already absent is not treated as an error, so repeated calls return nil.
func (s *Storage) Delete(context.Context) error {
	s.m.Lock()
	defer s.m.Unlock()

	if err := os.Remove(s.p); err != nil && !errors.Is(err, os.ErrNotExist) {
		return err
	}
	return nil
}

// Read returns the decrypted contents of the file, holding the read lock.
// When the file does not exist, Read returns a nil slice and a nil error.
// An empty file is returned as-is without being passed to DPAPI. Any other
// read error, or a decryption error, is returned to the caller.
func (s *Storage) Read(context.Context) ([]byte, error) {
	s.m.RLock()
	defer s.m.RUnlock()

	raw, err := os.ReadFile(s.p)
	switch {
	case errors.Is(err, os.ErrNotExist):
		return nil, nil
	case err != nil:
		return nil, err
	case len(raw) == 0:
		// Nothing to decrypt; hand back the empty contents unchanged.
		return raw, nil
	}
	return dpapi(decrypt, raw)
}

// Write encrypts data with DPAPI and stores the result in the file (mode 0600),
// holding the write lock. If the first write fails with os.ErrNotExist (for
// example because the parent directory is missing), Write creates the parent
// directory with mode 0700 and retries the write once.
func (s *Storage) Write(_ context.Context, data []byte) error {
	s.m.Lock()
	defer s.m.Unlock()

	ciphertext, err := dpapi(encrypt, data)
	if err != nil {
		return err
	}

	err = os.WriteFile(s.p, ciphertext, 0600)
	if !errors.Is(err, os.ErrNotExist) {
		// Either success (nil) or an error that creating directories won't fix.
		return err
	}
	if err = os.MkdirAll(filepath.Dir(s.p), 0700); err != nil {
		return err
	}
	return os.WriteFile(s.p, ciphertext, 0600)
}

// operation selects which DPAPI call dpapi makes.
type operation int

const (
	// decrypt unprotects data with CryptUnprotectData.
	decrypt operation = 0
	// encrypt protects data with CryptProtectData.
	encrypt operation = 1
)

// dpapi runs data through the Windows data protection API:
//   - op == encrypt calls CryptProtectData.
//   - op == decrypt calls CryptUnprotectData.
//
// Both calls pass CRYPTPROTECT_UI_FORBIDDEN and supply no description, entropy
// or prompt struct.
//
// The returned slice is a Go-owned copy of the output. The buffer Windows
// allocated for it is released with LocalFree before dpapi returns.
//
// Empty input yields a nil result and nil error without calling Windows. There
// is nothing to protect, and Read already treats an empty file as empty data,
// so an empty payload round-trips as an empty file.
//
// Errors:
//   - "invalid operation" for an unknown op.
//   - An error for inputs whose length does not fit DataBlob's 32-bit Size.
//   - Otherwise, the DPAPI error, or a LocalFree error when the DPAPI call
//     itself succeeded.
func dpapi(op operation, data []byte) (result []byte, err error) {
	if op != decrypt && op != encrypt {
		return nil, errors.New("invalid operation")
	}
	if len(data) == 0 {
		// &data[0] below would panic on an empty slice.
		return nil, nil
	}
	if uint64(len(data)) > uint64(^uint32(0)) {
		// DataBlob.Size is a uint32; converting a larger length would silently
		// truncate the input and process only part of it.
		return nil, errors.New("data too large for DPAPI")
	}

	out := windows.DataBlob{}
	defer func() {
		if out.Data != nil {
			_, e := windows.LocalFree(windows.Handle(unsafe.Pointer(out.Data)))
			// prefer returning DPAPI errors because they're more interesting than LocalFree errors
			if e != nil && err == nil {
				err = e
			}
		}
	}()

	in := windows.DataBlob{Data: &data[0], Size: uint32(len(data))}
	if op == decrypt {
		// https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptunprotectdata
		err = windows.CryptUnprotectData(&in, nil, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
	} else {
		// https://learn.microsoft.com/windows/win32/api/dpapi/nf-dpapi-cryptprotectdata
		err = windows.CryptProtectData(&in, nil, nil, 0, nil, windows.CRYPTPROTECT_UI_FORBIDDEN, &out)
	}
	if err != nil {
		return nil, err
	}

	// Copy out of the LocalAlloc'd buffer before the deferred LocalFree runs.
	result = make([]byte, out.Size)
	copy(result, unsafe.Slice(out.Data, out.Size))
	return result, nil
}

// Compile-time check that *Storage satisfies the Accessor interface.
var _ = Accessor((*Storage)(nil))