File: devices_nvidia_linux.go

package info (click to toggle)
docker.io 28.5.2+dfsg1-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 69,048 kB
  • sloc: sh: 5,867; makefile: 863; ansic: 184; python: 162; asm: 159
file content (127 lines) | stat: -rw-r--r-- 3,782 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package daemon

import (
	"os"
	"os/exec"
	"strconv"
	"strings"

	"github.com/containerd/containerd/v2/contrib/nvidia"
	"github.com/docker/docker/daemon/internal/capabilities"
	"github.com/opencontainers/runtime-spec/specs-go"
	"github.com/pkg/errors"
)

// TODO: nvidia should not be hard-coded, and should be a device plugin instead on the daemon object.
// TODO: add list of device capabilities in daemon/node info

const (
	// nvidiaHook is the NVIDIA helper binary. Finding it in $PATH enables the
	// "nvidia" device driver, and it is installed as the container's prestart
	// hook by setNvidiaGPUs.
	nvidiaHook = "nvidia-container-runtime-hook"

	// amdContainerRuntimeExecutableName is the AMD helper binary. Finding it in
	// $PATH enables the "amd" device driver when the NVIDIA hook is absent.
	amdContainerRuntimeExecutableName = "amd-container-runtime"
)

var (
	// errConflictCountDeviceIDs is returned when a device request sets both
	// Count and DeviceIDs, which are mutually exclusive ways to select GPUs.
	errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")

	// allNvidiaCaps is the set of NVIDIA driver capabilities, copied from the
	// allCaps list of containerd's contrib/nvidia package. Requested
	// capabilities that appear in this set are forwarded to the container via
	// NVIDIA_DRIVER_CAPABILITIES; the set is also advertised in the "nvidia"
	// driver's capability set.
	allNvidiaCaps = map[nvidia.Capability]struct{}{
		nvidia.Compute:  {},
		nvidia.Compat32: {},
		nvidia.Graphics: {},
		nvidia.Utility:  {},
		nvidia.Video:    {},
		nvidia.Display:  {},
	}
)

func init() {
	// binaryPresent reports whether the named helper binary resolves via $PATH.
	binaryPresent := func(name string) bool {
		_, err := exec.LookPath(name)
		return err == nil
	}

	// At most one GPU driver is registered, and NVIDIA takes precedence over
	// AMD: the cases are evaluated in order and only the first match runs.
	// When neither helper is installed, no "gpu" capability is offered.
	switch {
	case binaryPresent(nvidiaHook):
		nvidiaCapset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}}
		for nvCap := range allNvidiaCaps {
			nvidiaCapset[string(nvCap)] = struct{}{}
		}
		registerDeviceDriver("nvidia", &deviceDriver{
			capset:     nvidiaCapset,
			updateSpec: setNvidiaGPUs,
		})
	case binaryPresent(amdContainerRuntimeExecutableName):
		registerDeviceDriver("amd", &deviceDriver{
			capset:     capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
			updateSpec: setAMDGPUs,
		})
	}
}

// setNvidiaGPUs updates the OCI spec s so the container gets the NVIDIA GPUs
// described by the device request in dev.
//
// GPU selection is passed through the NVIDIA_VISIBLE_DEVICES environment
// variable:
//   - explicit DeviceIDs are joined into a comma-separated list;
//   - a positive Count selects device indices 0 through Count-1;
//   - a negative Count selects "all";
//   - a zero Count (and no DeviceIDs) selects "void".
//
// Selected capabilities that are NVIDIA driver capabilities (see
// allNvidiaCaps) are forwarded via NVIDIA_DRIVER_CAPABILITIES. Finally the
// nvidia-container-runtime-hook is appended as a prestart hook, running with
// the daemon's own environment.
//
// It returns errConflictCountDeviceIDs if both Count and DeviceIDs are set,
// or the exec.LookPath error if the hook binary cannot be found. In both
// error cases s is left unmodified.
func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {
	req := dev.req
	if req.Count != 0 && len(req.DeviceIDs) > 0 {
		return errConflictCountDeviceIDs
	}

	// Resolve the hook before touching the spec, so a missing binary does not
	// leave s with a partially-applied NVIDIA configuration.
	hookPath, err := exec.LookPath(nvidiaHook)
	if err != nil {
		return err
	}

	var visibleDevices string
	switch {
	case len(req.DeviceIDs) > 0:
		visibleDevices = strings.Join(req.DeviceIDs, ",")
	case req.Count > 0:
		visibleDevices = countToDevices(req.Count)
	case req.Count < 0:
		visibleDevices = "all"
	default: // req.Count == 0
		visibleDevices = "void"
	}
	s.Process.Env = append(s.Process.Env, "NVIDIA_VISIBLE_DEVICES="+visibleDevices)

	// dev.selectedCaps holds the requested device capabilities; only some of
	// them are NVIDIA driver capabilities, and only those are forwarded.
	var nvidiaCaps []string
	for _, c := range dev.selectedCaps {
		if _, isNvidiaCap := allNvidiaCaps[nvidia.Capability(c)]; isNvidiaCap {
			nvidiaCaps = append(nvidiaCaps, c)
		}
		// Other capabilities are currently ignored here.
		// TODO: nvidia.WithRequiredCUDAVersion
		// for now we let the prestart hook verify cuda versions but errors are not pretty.
	}
	if len(nvidiaCaps) > 0 {
		s.Process.Env = append(s.Process.Env, "NVIDIA_DRIVER_CAPABILITIES="+strings.Join(nvidiaCaps, ","))
	}

	if s.Hooks == nil {
		s.Hooks = &specs.Hooks{}
	}

	// This implementation uses prestart hooks, which are deprecated.
	// CreateRuntime is the closest equivalent, and executed in the same
	// locations as prestart-hooks, but depending on what these hooks do,
	// possibly one of the other hooks could be used instead (such as
	// CreateContainer or StartContainer).
	s.Hooks.Prestart = append(s.Hooks.Prestart, specs.Hook{ //nolint:staticcheck // FIXME(thaJeztah); replace prestart hook with a non-deprecated one.
		Path: hookPath,
		Args: []string{
			nvidiaHook,
			"prestart",
		},
		Env: os.Environ(),
	})

	return nil
}

// countToDevices renders the first count device indices as a comma-separated
// list, e.g. 3 becomes "0,1,2" and 0 becomes "". A negative count panics,
// exactly as make does for a negative capacity.
func countToDevices(count int) string {
	ids := make([]string, 0, count)
	for idx := 0; idx < count; idx++ {
		ids = append(ids, strconv.Itoa(idx))
	}
	return strings.Join(ids, ",")
}