File: context.go

package info (click to toggle)
golang-github-lk4d4-joincontext 0.0%2Bgit20171026.1724345-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, sid, trixie
  • size: 88 kB
  • sloc: makefile: 2
file content (102 lines) | stat: -rw-r--r-- 1,957 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Package joincontext provides a way to combine two contexts.
// For example it might be useful for grpc server to cancel all handlers in
// addition to provided handler context.
package joincontext

import (
	"sync"
	"time"

	"golang.org/x/net/context"
)

// joinContext is a context.Context that becomes done as soon as either of its
// two parents is done, or when the CancelFunc returned by Join is called.
//
// mu guards err and the closing of done: done is closed exactly once, while mu
// is held, at the same moment err becomes non-nil.
type joinContext struct {
	mu   sync.Mutex
	ctx1 context.Context
	ctx2 context.Context
	done chan struct{}
	err  error
}

// Compile-time check that *joinContext satisfies context.Context.
var _ context.Context = (*joinContext)(nil)

// Join returns a new context that is a child of both ctx1 and ctx2.
// It starts a goroutine that watches both parents; the goroutine exits once
// the returned context is done. Callers should therefore always call the
// returned CancelFunc (typically via defer), otherwise the goroutine lives as
// long as both parents do.
//
// Done() is closed when either parent is done or the CancelFunc is called.
//
// Deadline() returns the earlier of the parents' deadlines, if any.
//
// Err() returns the error of the first parent observed to be done, or
// context.Canceled if the CancelFunc was called first.
//
// Value(key) looks key up in ctx1 first, then in ctx2; the first non-nil
// value is returned.
func Join(ctx1, ctx2 context.Context) (context.Context, context.CancelFunc) {
	c := &joinContext{ctx1: ctx1, ctx2: ctx2, done: make(chan struct{})}
	go c.run()
	return c, c.cancel
}

// Deadline returns the earlier of the two parents' deadlines. ok is false
// only when neither parent has a deadline.
func (c *joinContext) Deadline() (deadline time.Time, ok bool) {
	d1, ok1 := c.ctx1.Deadline()
	d2, ok2 := c.ctx2.Deadline()
	switch {
	case !ok1:
		return d2, ok2
	case !ok2:
		return d1, true
	case d2.Before(d1):
		return d2, true
	default:
		return d1, true
	}
}

// Done returns a channel that is closed when c is finished.
func (c *joinContext) Done() <-chan struct{} {
	return c.done
}

// Err returns nil while Done is open, and the finishing error afterwards.
func (c *joinContext) Err() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.err
}

// Value returns ctx1's value for key if it is non-nil, otherwise ctx2's.
func (c *joinContext) Value(key interface{}) interface{} {
	if v := c.ctx1.Value(key); v != nil {
		return v
	}
	return c.ctx2.Value(key)
}

// run waits for either parent to be done and then finishes c with that
// parent's error. It returns immediately if c is finished first (via cancel).
func (c *joinContext) run() {
	var err error
	select {
	case <-c.ctx1.Done():
		err = c.ctx1.Err()
	case <-c.ctx2.Done():
		err = c.ctx2.Err()
	case <-c.done:
		return
	}
	c.cancelWith(err)
}

// cancel finishes c with context.Canceled. It is safe to call more than once
// and concurrently with run; only the first finish takes effect.
func (c *joinContext) cancel() {
	c.cancelWith(context.Canceled)
}

// cancelWith records err and closes done, unless c is already finished.
// done is closed while mu is held so that Err can never report a non-nil
// error while Done is still open, as the context.Context contract requires.
func (c *joinContext) cancelWith(err error) {
	// A well-behaved parent never reports a nil Err once Done is closed, but
	// guard anyway: a nil err here would let a later call close done twice.
	if err == nil {
		err = context.Canceled
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.err != nil {
		return
	}
	c.err = err
	close(c.done)
}