diff --git a/pkg/nvcdi/lib-vfio.go b/pkg/nvcdi/lib-vfio.go
new file mode 100644
index 000000000..facd40c85
--- /dev/null
+++ b/pkg/nvcdi/lib-vfio.go
@@ -0,0 +1,141 @@
+/**
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package nvcdi
+
+import (
+	"fmt"
+	"path/filepath"
+	"strconv"
+
+	"tags.cncf.io/container-device-interface/pkg/cdi"
+	"tags.cncf.io/container-device-interface/specs-go"
+)
+
+// vfiolib is the nvcdilib variant used to generate CDI specs for NVIDIA GPUs
+// that are exposed through /dev/vfio device nodes.
+type vfiolib nvcdilib
+
+// vfioDevice identifies a single GPU by its index in the PCI library's GPU
+// list and by the IOMMU group that names its /dev/vfio/<group> node.
+type vfioDevice struct {
+	index   int
+	group   int
+	devRoot string
+}
+
+var _ deviceSpecGeneratorFactory = (*vfiolib)(nil)
+
+// DeviceSpecGenerators returns one spec generator per requested device.
+// Each id is either a decimal GPU index or the literal "all"; see
+// getVfioDevices for how the ids are resolved.
+func (l *vfiolib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
+	vfioDevices, err := l.getVfioDevices(ids...)
+	if err != nil {
+		return nil, err
+	}
+	var deviceSpecGenerators DeviceSpecGenerators
+	for _, vfioDevice := range vfioDevices {
+		deviceSpecGenerators = append(deviceSpecGenerators, vfioDevice)
+	}
+
+	return deviceSpecGenerators, nil
+}
+
+// GetDeviceSpecs returns the CDI device specs for a vfio device.
+// The device is named after its index and carries a single device node,
+// /dev/vfio/<group>, whose host path is resolved relative to devRoot.
+func (d *vfioDevice) GetDeviceSpecs() ([]specs.Device, error) {
+	path := fmt.Sprintf("/dev/vfio/%d", d.group)
+	deviceSpec := specs.Device{
+		Name: strconv.Itoa(d.index),
+		ContainerEdits: specs.ContainerEdits{
+			DeviceNodes: []*specs.DeviceNode{
+				{
+					Path:     path,
+					HostPath: filepath.Join(d.devRoot, path),
+				},
+			},
+		},
+	}
+	return []specs.Device{deviceSpec}, nil
+}
+
+// GetCommonEdits returns common edits for ALL devices.
+// The only common edit is the /dev/vfio/vfio device node, whose host path
+// is resolved relative to devRoot.
+func (l *vfiolib) GetCommonEdits() (*cdi.ContainerEdits, error) {
+	e := cdi.ContainerEdits{
+		ContainerEdits: &specs.ContainerEdits{
+			DeviceNodes: []*specs.DeviceNode{
+				{
+					Path:     "/dev/vfio/vfio",
+					HostPath: filepath.Join(l.devRoot, "/dev/vfio/vfio"),
+				},
+			},
+		},
+	}
+	return &e, nil
+}
+
+// getVfioDevices resolves the requested ids to vfio devices.
+// An id of "all" short-circuits and returns every GPU bound to vfio-pci,
+// discarding any devices resolved from earlier ids. Any other id must be a
+// decimal index that is looked up with GetGPUByIndex.
+// NOTE(review): unlike the "all" path, devices selected by index are not
+// checked for the vfio-pci driver -- confirm this is intended.
+func (l *vfiolib) getVfioDevices(ids ...string) ([]*vfioDevice, error) {
+	var vfioDevices []*vfioDevice
+	for _, id := range ids {
+		if id == "all" {
+			return l.getAllVfioDevices()
+		}
+		index, err := strconv.ParseInt(id, 10, 32)
+		if err != nil {
+			return nil, fmt.Errorf("invalid device index %q: %w", id, err)
+		}
+		i := int(index)
+		dev, err := l.nvpcilib.GetGPUByIndex(i)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get device: %w", err)
+		}
+		vfioDevices = append(vfioDevices, &vfioDevice{index: i, group: dev.IommuGroup, devRoot: l.devRoot})
+	}
+
+	return vfioDevices, nil
+}
+
+// getAllVfioDevices returns every GPU reported by the PCI library that is
+// bound to the vfio-pci driver. A device keeps its index in the full GPU
+// list, so skipped devices leave gaps in the generated device names.
+func (l *vfiolib) getAllVfioDevices() ([]*vfioDevice, error) {
+	devices, err := l.nvpcilib.GetGPUs()
+	if err != nil {
+		return nil, fmt.Errorf("failed getting NVIDIA GPUs: %w", err)
+	}
+
+	var vfioDevices []*vfioDevice
+	for i, dev := range devices {
+		if dev.Driver != "vfio-pci" {
+			continue
+		}
+		l.logger.Debugf("Found NVIDIA device: address=%s, driver=%s, iommu_group=%d, deviceId=%x",
+			dev.Address, dev.Driver, dev.IommuGroup, dev.Device)
+		vfioDevices = append(vfioDevices, &vfioDevice{index: i, group: dev.IommuGroup, devRoot: l.devRoot})
+	}
+	return vfioDevices, nil
+}
diff --git
a/pkg/nvcdi/lib-vfio_test.go b/pkg/nvcdi/lib-vfio_test.go new file mode 100644 index 000000000..0b638c7fb --- /dev/null +++ b/pkg/nvcdi/lib-vfio_test.go @@ -0,0 +1,123 @@ +/** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "bytes" + "testing" + + "github.com/NVIDIA/go-nvlib/pkg/nvpci" + "github.com/stretchr/testify/require" +) + +func TestModeVfio(t *testing.T) { + testCases := []struct { + description string + pcilib *nvpci.InterfaceMock + ids []string + expectedError error + expectedSpec string + }{ + { + description: "get all specs single device", + pcilib: &nvpci.InterfaceMock{ + GetGPUsFunc: func() ([]*nvpci.NvidiaPCIDevice, error) { + devices := []*nvpci.NvidiaPCIDevice{ + { + Driver: "vfio-pci", + IommuGroup: 5, + }, + } + return devices, nil + }, + }, + expectedSpec: `--- +cdiVersion: 0.5.0 +kind: nvidia.com/pgpu +devices: + - name: "0" + containerEdits: + deviceNodes: + - path: /dev/vfio/5 + hostPath: /dev/vfio/5 +containerEdits: + env: + - NVIDIA_VISIBLE_DEVICES=void + deviceNodes: + - path: /dev/vfio/vfio + hostPath: /dev/vfio/vfio +`, + }, + { + description: "get single device spec by index", + pcilib: &nvpci.InterfaceMock{ + GetGPUByIndexFunc: func(n int) (*nvpci.NvidiaPCIDevice, error) { + devices := []*nvpci.NvidiaPCIDevice{ + { + Driver: "vfio-pci", + IommuGroup: 45, + }, + { + 
Driver: "vfio-pci", + IommuGroup: 5, + }, + } + return devices[n], nil + }, + }, + ids: []string{"1"}, + expectedSpec: `--- +cdiVersion: 0.5.0 +kind: nvidia.com/pgpu +devices: + - name: "1" + containerEdits: + deviceNodes: + - path: /dev/vfio/5 + hostPath: /dev/vfio/5 +containerEdits: + env: + - NVIDIA_VISIBLE_DEVICES=void + deviceNodes: + - path: /dev/vfio/vfio + hostPath: /dev/vfio/vfio +`, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + lib, err := New( + WithMode(ModeVfio), + WithPCILib(tc.pcilib), + ) + require.NoError(t, err) + + spec, err := lib.GetSpec(tc.ids...) + require.EqualValues(t, tc.expectedError, err) + + var output bytes.Buffer + + _, err = spec.WriteTo(&output) + require.NoError(t, err) + + require.Equal(t, tc.expectedSpec, output.String()) + }) + } + +} diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 4369a7215..3976f8763 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -21,6 +21,7 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" + "github.com/NVIDIA/go-nvlib/pkg/nvpci" "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" @@ -54,6 +55,8 @@ type nvcdilib struct { driver *root.Driver infolib info.Interface + nvpcilib nvpci.Interface + mergedDeviceOptions []transform.MergedDeviceOption featureFlags map[FeatureFlag]bool @@ -140,6 +143,14 @@ func New(opts ...Option) (Interface, error) { l.class = classImexChannel } factory = (*imexlib)(l) + case ModeVfio: + if l.class == "" { + l.class = "pgpu" + } + if l.nvpcilib == nil { + l.nvpcilib = nvpci.New() + } + factory = (*vfiolib)(l) default: return nil, fmt.Errorf("unknown mode %q", l.mode) } diff --git a/pkg/nvcdi/mode.go b/pkg/nvcdi/mode.go index a68170ece..a19f472c3 100644 --- a/pkg/nvcdi/mode.go +++ b/pkg/nvcdi/mode.go @@ -46,6 +46,8 @@ const ( ModeImex = Mode("imex") // ModeNvswitch configures the CDI spec generator to generate a spec for the 
available nvswitch devices. ModeNvswitch = Mode("nvswitch") + // ModeVfio configures the CDI spec generator to generate a spec for NVIDIA GPUs exposed through /dev/vfio device nodes. + ModeVfio = Mode("vfio") ) type modeConstraint interface { @@ -72,6 +74,7 @@ func getModes() modes { ModeMofed, ModeNvml, ModeNvswitch, + ModeVfio, ModeWsl, } lookup := make(map[Mode]bool) diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 550b18bad..f033f04c0 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -19,6 +19,7 @@ package nvcdi import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" + "github.com/NVIDIA/go-nvlib/pkg/nvpci" "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" @@ -43,6 +44,13 @@ func WithInfoLib(infolib info.Interface) Option { } } +// WithPCILib sets the PCI library used to enumerate GPUs for CDI spec generation (consulted by ModeVfio; New falls back to nvpci.New() when unset). +func WithPCILib(pcilib nvpci.Interface) Option { + return func(l *nvcdilib) { + l.nvpcilib = pcilib + } +} + // WithDeviceNamers sets the device namer for the library func WithDeviceNamers(namers ...DeviceNamer) Option { return func(l *nvcdilib) {