@@ -19,41 +19,74 @@ package nvcdi
1919import (
2020 "fmt"
2121
22- "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322 "tags.cncf.io/container-device-interface/pkg/cdi"
2423 "tags.cncf.io/container-device-interface/specs-go"
2524
25+ "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
26+ "github.com/NVIDIA/go-nvml/pkg/nvml"
27+
2628 "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2729 "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2830 "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/dgpu"
2931)
3032
31- // GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
32- func (l * nvmllib ) GetGPUDeviceSpecs (i int , d device.Device ) ([]specs.Device , error ) {
33- edits , err := l .GetGPUDeviceEdits (d )
33+ // A fullGPUDeviceSpecGenerator generates the CDI device specifications for a
34+ // single full GPU.
35+ type fullGPUDeviceSpecGenerator struct {
36+ * nvmllib
37+ id string
38+ index int
39+ device device.Device
40+ }
41+
42+ var _ DeviceSpecGenerator = (* fullGPUDeviceSpecGenerator )(nil )
43+
44+ func (l * nvmllib ) newFullGPUDeviceSpecGeneratorFromNVMLDevice (id string , nvmlDevice nvml.Device ) (DeviceSpecGenerator , error ) {
45+ device , err := l .devicelib .NewDevice (nvmlDevice )
3446 if err != nil {
35- return nil , fmt . Errorf ( "failed to get edits for device: %v" , err )
47+ return nil , err
3648 }
3749
38- var deviceSpecs []specs.Device
39- names , err := l .deviceNamers .GetDeviceNames (i , convert {d })
50+ index , ret := nvmlDevice .GetIndex ()
51+ if ret != nvml .SUCCESS {
52+ return nil , fmt .Errorf ("failed to get device index: %v" , ret )
53+ }
54+
55+ e := & fullGPUDeviceSpecGenerator {
56+ nvmllib : l ,
57+ id : id ,
58+ index : index ,
59+ device : device ,
60+ }
61+ return e , nil
62+ }
63+
64+ func (l * fullGPUDeviceSpecGenerator ) GetDeviceSpecs () ([]specs.Device , error ) {
65+ deviceEdits , err := l .getDeviceEdits ()
66+ if err != nil {
67+ return nil , fmt .Errorf ("failed to get CDI device edits for identifier %q: %w" , l .id , err )
68+ }
69+
70+ names , err := l .getNames ()
4071 if err != nil {
41- return nil , fmt .Errorf ("failed to get device name : %v " , err )
72+ return nil , fmt .Errorf ("failed to get device names : %w " , err )
4273 }
74+
75+ var deviceSpecs []specs.Device
4376 for _ , name := range names {
44- spec := specs.Device {
77+ deviceSpec := specs.Device {
4578 Name : name ,
46- ContainerEdits : * edits .ContainerEdits ,
79+ ContainerEdits : * deviceEdits .ContainerEdits ,
4780 }
48- deviceSpecs = append (deviceSpecs , spec )
81+ deviceSpecs = append (deviceSpecs , deviceSpec )
4982 }
5083
5184 return deviceSpecs , nil
5285}
5386
5487// GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'.
55- func (l * nvmllib ) GetGPUDeviceEdits ( d device. Device ) (* cdi.ContainerEdits , error ) {
56- device , err := l .newFullGPUDiscoverer (d )
88+ func (l * fullGPUDeviceSpecGenerator ) getDeviceEdits ( ) (* cdi.ContainerEdits , error ) {
89+ device , err := l .newFullGPUDiscoverer (l . device )
5790 if err != nil {
5891 return nil , fmt .Errorf ("failed to create device discoverer: %v" , err )
5992 }
@@ -66,8 +99,12 @@ func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error
6699 return editsForDevice , nil
67100}
68101
102+ func (l * fullGPUDeviceSpecGenerator ) getNames () ([]string , error ) {
103+ return l .deviceNamers .GetDeviceNames (l .index , convert {l .device })
104+ }
105+
69106// newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device.
70- func (l * nvmllib ) newFullGPUDiscoverer (d device.Device ) (discover.Discover , error ) {
107+ func (l * fullGPUDeviceSpecGenerator ) newFullGPUDiscoverer (d device.Device ) (discover.Discover , error ) {
71108 deviceNodes , err := dgpu .NewForDevice (d ,
72109 dgpu .WithDevRoot (l .devRoot ),
73110 dgpu .WithLogger (l .logger ),
0 commit comments