@@ -23,34 +23,114 @@ import (
2323 "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2424)
2525
26+ // A RuntimeMode is used to select a specific mode of operation for the NVIDIA Container Runtime.
27+ type RuntimeMode string
28+
29+ const (
30+ // In LegacyRuntimeMode the nvidia-container-runtime injects the
31+ // nvidia-container-runtime-hook as a prestart hook into the incoming
32+ // container config. This hook invokes the nvidia-container-cli to perform
33+ // the required modifications to the container.
34+ LegacyRuntimeMode = RuntimeMode ("legacy" )
35+ // In CSVRuntimeMode the nvidia-container-runtime processes a set of CSV
36+ // files to determine which container modification are required. The
37+ // contents of these CSV files are used to generate an in-memory CDI
38+ // specification which is used to modify the container config.
39+ CSVRuntimeMode = RuntimeMode ("csv" )
40+ // In CDIRuntimeMode the nvidia-container-runtime applies the modifications
41+ // to the container config required for the requested CDI devices in the
42+ // same way that other CDI clients would.
43+ CDIRuntimeMode = RuntimeMode ("cdi" )
44+ // In JitCDIRuntimeMode the nvidia-container-runtime generates in-memory CDI
45+ // specifications for requested NVIDIA devices.
46+ JitCDIRuntimeMode = RuntimeMode ("jit-cdi" )
47+ )
48+
49+ type RuntimeModeResolver interface {
50+ ResolveRuntimeMode (string ) RuntimeMode
51+ }
52+
53+ type modeResolver struct {
54+ logger logger.Interface
55+ // TODO: This only needs to consider the requested devices.
56+ image * image.CUDA
57+ propertyExtractor info.PropertyExtractor
58+ defaultMode RuntimeMode
59+ }
60+
61+ type Option func (* modeResolver )
62+
63+ func WithDefaultMode (defaultMode RuntimeMode ) Option {
64+ return func (mr * modeResolver ) {
65+ mr .defaultMode = defaultMode
66+ }
67+ }
68+
69+ func WithLogger (logger logger.Interface ) Option {
70+ return func (mr * modeResolver ) {
71+ mr .logger = logger
72+ }
73+ }
74+
75+ func WithImage (image * image.CUDA ) Option {
76+ return func (mr * modeResolver ) {
77+ mr .image = image
78+ }
79+ }
80+
81+ func WithPropertyExtractor (propertyExtractor info.PropertyExtractor ) Option {
82+ return func (mr * modeResolver ) {
83+ mr .propertyExtractor = propertyExtractor
84+ }
85+ }
86+
87+ func NewRuntimeModeResolver (opts ... Option ) RuntimeModeResolver {
88+ r := & modeResolver {
89+ defaultMode : JitCDIRuntimeMode ,
90+ }
91+ for _ , opt := range opts {
92+ opt (r )
93+ }
94+ if r .logger == nil {
95+ r .logger = & logger.NullLogger {}
96+ }
97+
98+ return r
99+ }
100+
26101// ResolveAutoMode determines the correct mode for the platform if set to "auto"
27- func ResolveAutoMode (logger logger.Interface , mode string , image image.CUDA ) (rmode string ) {
28- return resolveMode (logger , mode , image , nil )
102+ func ResolveAutoMode (logger logger.Interface , mode string , image image.CUDA ) (rmode RuntimeMode ) {
103+ r := modeResolver {
104+ logger : logger ,
105+ image : & image ,
106+ propertyExtractor : nil ,
107+ }
108+ return r .ResolveRuntimeMode (mode )
29109}
30110
31- func resolveMode ( logger logger. Interface , mode string , image image. CUDA , propertyExtractor info. PropertyExtractor ) (rmode string ) {
111+ func ( m * modeResolver ) ResolveRuntimeMode ( mode string ) (rmode RuntimeMode ) {
32112 if mode != "auto" {
33- logger .Infof ("Using requested mode '%s'" , mode )
34- return mode
113+ m . logger .Infof ("Using requested mode '%s'" , mode )
114+ return RuntimeMode ( mode )
35115 }
36116 defer func () {
37- logger .Infof ("Auto-detected mode as '%v'" , rmode )
117+ m . logger .Infof ("Auto-detected mode as '%v'" , rmode )
38118 }()
39119
40- if image .OnlyFullyQualifiedCDIDevices () {
41- return "cdi"
120+ if m . image .OnlyFullyQualifiedCDIDevices () {
121+ return CDIRuntimeMode
42122 }
43123
44124 nvinfo := info .New (
45- info .WithLogger (logger ),
46- info .WithPropertyExtractor (propertyExtractor ),
125+ info .WithLogger (m . logger ),
126+ info .WithPropertyExtractor (m . propertyExtractor ),
47127 )
48128
49129 switch nvinfo .ResolvePlatform () {
50130 case info .PlatformNVML , info .PlatformWSL :
51- return "legacy"
131+ return m . defaultMode
52132 case info .PlatformTegra :
53- return "csv"
133+ return CSVRuntimeMode
54134 }
55- return "legacy"
135+ return m . defaultMode
56136}
0 commit comments