File: cache.go

package info (click to toggle)
golang-google-api 0.214.0-2
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 317,208 kB
  • sloc: sh: 211; makefile: 26
file content (123 lines) | stat: -rw-r--r-- 2,762 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Copyright 2020 Google LLC.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package idtoken

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"
)

// cachingClient fetches cert responses over HTTP and keeps each one in an
// in-memory cache, keyed by URL, until an expiry derived from its headers.
type cachingClient struct {
	// client performs the HTTP requests for cert documents.
	client *http.Client

	// clock, when non-nil, supplies the current time; when it is nil the
	// client falls back to time.Now.
	clock func() time.Time

	// mu guards certs, which maps each fetched URL to its cached response.
	mu    sync.Mutex
	certs map[string]*cachedResponse
}

// newCachingClient wraps client in a cachingClient whose cert cache starts
// out empty.
func newCachingClient(client *http.Client) *cachingClient {
	cc := new(cachingClient)
	cc.client = client
	// Small initial capacity: only a handful of cert URLs are expected
	// (presumably; verify against callers).
	cc.certs = make(map[string]*cachedResponse, 2)
	return cc
}

// cachedResponse is one cache entry: a decoded cert response paired with the
// instant after which it is no longer served from the cache.
type cachedResponse struct {
	// resp is the decoded cert response.
	resp *certResponse
	// exp is the expiry; the entry is stale once the current time is after it.
	exp time.Time
}

// getCert returns the cert response published at url. An unexpired cached
// copy is returned directly. Otherwise it issues a GET request bound to ctx
// and requires a 200 OK status. It then decodes the JSON body into a
// certResponse and caches the result with an expiry derived from the
// response headers.
func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) {
	if response, ok := c.get(url); ok {
		return response, nil
	}
	// NewRequestWithContext binds ctx when the request is built, rather than
	// building it and then shallow-copying it via Request.WithContext. A nil
	// ctx becomes an error instead of a panic.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	resp, err := c.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
	}

	certResp := &certResponse{}
	if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
		// Wrap with %w so callers can still inspect the underlying
		// *json.SyntaxError etc. via errors.As.
		return nil, fmt.Errorf("idtoken: unable to decode cert response: %w", err)
	}
	c.set(url, certResp, resp.Header)
	return certResp, nil
}

// now reports the current time. It uses the injected clock when one is
// configured and falls back to time.Now otherwise.
func (c *cachingClient) now() time.Time {
	clock := c.clock
	if clock == nil {
		clock = time.Now
	}
	return clock()
}

// get looks up url in the cache while holding mu. It returns the cached
// response and true only when an entry exists and the current time is not
// yet past its expiry. A missing or expired entry yields nil and false.
func (c *cachingClient) get(url string) (*certResponse, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if entry, found := c.certs[url]; found && !c.now().After(entry.exp) {
		return entry.resp, true
	}
	return nil, false
}

// set caches resp under url, replacing any previous entry. The expiry is
// derived from the response headers before the lock is taken, so the map is
// only locked for the assignment itself.
func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) {
	entry := &cachedResponse{
		resp: resp,
		exp:  c.calculateExpireTime(headers),
	}
	c.mu.Lock()
	c.certs[url] = entry
	c.mu.Unlock()
}

// calculateExpireTime determines when a cached cert response expires, based
// on its HTTP headers. The expiry is the current time plus the Cache-Control
// max-age directive minus the Age header, both measured in seconds. A
// response without a max-age directive gets a lifetime of zero. If either
// header is present but cannot be parsed, the fallback is to set the cache
// to expire now.
func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time {
	// RFC 9111 §1.2.2 lets caches treat oversized delta-seconds values as
	// 2^31. Clamping also keeps the time.Duration multiplication below from
	// overflowing int64 nanoseconds.
	const maxDeltaSeconds = 1 << 31

	// parseDeltaSeconds parses a non-negative number of seconds. It tolerates
	// surrounding whitespace and the quoted-string form (max-age="60").
	parseDeltaSeconds := func(s string) (int64, bool) {
		s = strings.Trim(strings.TrimSpace(s), `"`)
		n, err := strconv.ParseInt(s, 10, 64)
		if err != nil || n < 0 {
			return 0, false
		}
		if n > maxDeltaSeconds {
			n = maxDeltaSeconds
		}
		return n, true
	}

	var maxAge int64
	// A response may carry several Cache-Control header lines, each holding
	// a comma-separated directive list. Indexing with the canonical key is
	// what headers.Get does, but it yields every line rather than the first.
	// If max-age appears more than once, the last occurrence wins.
	for _, line := range headers["Cache-Control"] {
		for _, directive := range strings.Split(line, ",") {
			parts := strings.SplitN(directive, "=", 2)
			// Directive names are case-insensitive (RFC 9111 §5.2).
			if !strings.EqualFold(strings.TrimSpace(parts[0]), "max-age") {
				continue
			}
			if len(parts) < 2 {
				return c.now()
			}
			ma, ok := parseDeltaSeconds(parts[1])
			if !ok {
				return c.now()
			}
			maxAge = ma
		}
	}

	var age int64
	if a := strings.TrimSpace(headers.Get("Age")); a != "" {
		parsed, ok := parseDeltaSeconds(a)
		if !ok {
			return c.now()
		}
		age = parsed
	}
	// When age exceeds max-age the result lies in the past, so the entry is
	// already expired.
	return c.now().Add(time.Duration(maxAge-age) * time.Second)
}