diff --git a/cmd/nvidia-ctk-installer/toolkit/toolkit.go b/cmd/nvidia-ctk-installer/toolkit/toolkit.go index 73d887572..bf970473b 100644 --- a/cmd/nvidia-ctk-installer/toolkit/toolkit.go +++ b/cmd/nvidia-ctk-installer/toolkit/toolkit.go @@ -30,9 +30,9 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit/installer" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" transformroot "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform/root" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvdevices" ) const ( diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go index aff17ada2..c513b1898 100644 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go +++ b/cmd/nvidia-ctk/system/create-dev-char-symlinks/create-dev-char-symlinks.go @@ -19,14 +19,12 @@ package devchar import ( "context" "fmt" - "os" - "path/filepath" "github.com/urfave/cli/v3" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvdevices" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvmodules" ) const ( @@ -155,7 +153,6 @@ func (m command) run(cfg *config) error { type linkCreator struct { logger logger.Interface - lister nodeLister driverRoot string devRoot string devCharPath string @@ -163,6 +160,8 @@ type linkCreator struct { createAll bool createDeviceNodes bool loadKernelModules bool + + devicesLib *nvdevices.Interface } // Creator is an interface for creating symlinks to /dev/nv* devices in 
/dev/char. @@ -174,6 +173,8 @@ type Creator interface { type Option func(*linkCreator) // NewSymlinkCreator creates a new linkCreator. +// +// Deprecated: Use the `nvdevices` package instead. func NewSymlinkCreator(opts ...Option) (Creator, error) { c := linkCreator{} for _, opt := range opts { @@ -192,52 +193,34 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) { c.devCharPath = defaultDevCharPath } - if err := c.setup(); err != nil { - return nil, err - } - - if c.createAll { - lister, err := newAllPossible(c.logger, c.devRoot) - if err != nil { - return nil, fmt.Errorf("failed to create all possible device lister: %v", err) - } - c.lister = lister - } else { - c.lister = existing{c.logger, c.devRoot} - } - return c, nil -} - -func (m linkCreator) setup() error { - if !m.loadKernelModules && !m.createDeviceNodes { - return nil - } - - if m.loadKernelModules { + if c.loadKernelModules { modules := nvmodules.New( - nvmodules.WithLogger(m.logger), - nvmodules.WithDryRun(m.dryRun), - nvmodules.WithRoot(m.driverRoot), + nvmodules.WithLogger(c.logger), + nvmodules.WithDryRun(c.dryRun), + nvmodules.WithRoot(c.driverRoot), ) if err := modules.LoadAll(); err != nil { - return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) + return nil, fmt.Errorf("failed to load NVIDIA kernel modules: %v", err) } } - if m.createDeviceNodes { - devices, err := nvdevices.New( - nvdevices.WithLogger(m.logger), - nvdevices.WithDryRun(m.dryRun), - nvdevices.WithDevRoot(m.devRoot), - ) - if err != nil { - return err - } + devices, err := nvdevices.New( + nvdevices.WithLogger(c.logger), + nvdevices.WithDryRun(c.dryRun), + nvdevices.WithDevRoot(c.driverRoot), + ) + if err != nil { + return nil, err + } + c.devicesLib = devices + + if c.createDeviceNodes { if err := devices.CreateNVIDIAControlDevices(); err != nil { - return fmt.Errorf("failed to create NVIDIA device nodes: %v", err) + return nil, fmt.Errorf("failed to create NVIDIA device nodes: %v", err) } } - return nil + 
+ return c, nil } // WithDriverRoot sets the driver root path. @@ -299,42 +282,5 @@ func WithCreateDeviceNodes(createDeviceNodes bool) Option { // CreateLinks creates symlinks for all NVIDIA device nodes found in the driver root. func (m linkCreator) CreateLinks() error { - deviceNodes, err := m.lister.DeviceNodes() - if err != nil { - return fmt.Errorf("failed to get device nodes: %v", err) - } - - if len(deviceNodes) != 0 && !m.dryRun { - err := os.MkdirAll(m.devCharPath, 0755) - if err != nil { - return fmt.Errorf("failed to create directory %s: %v", m.devCharPath, err) - } - } - - for _, deviceNode := range deviceNodes { - target := deviceNode.path - linkPath := filepath.Join(m.devCharPath, deviceNode.devCharName()) - - m.logger.Infof("Creating link %s => %s", linkPath, target) - if m.dryRun { - continue - } - - err = os.Symlink(target, linkPath) - if err != nil { - m.logger.Warningf("Could not create symlink: %v", err) - } - } - - return nil -} - -type deviceNode struct { - path string - major uint32 - minor uint32 -} - -func (d deviceNode) devCharName() string { - return fmt.Sprintf("%d:%d", d.major, d.minor) + return m.devicesLib.CreateDevCharSymlinks(m.devCharPath, !m.createAll) } diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing.go deleted file mode 100644 index d022a98fe..000000000 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing.go +++ /dev/null @@ -1,89 +0,0 @@ -/** -# Copyright (c) NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package devchar - -import ( - "path/filepath" - "strings" - - "golang.org/x/sys/unix" - - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" -) - -type nodeLister interface { - DeviceNodes() ([]deviceNode, error) -} - -type existing struct { - logger logger.Interface - devRoot string -} - -// DeviceNodes returns a list of NVIDIA device nodes in the specified root. -// The nvidia-nvswitch* and nvidia-nvlink devices are excluded. -func (m existing) DeviceNodes() ([]deviceNode, error) { - locator := lookup.NewCharDeviceLocator( - lookup.WithLogger(m.logger), - lookup.WithRoot(m.devRoot), - lookup.WithOptional(true), - ) - - devices, err := locator.Locate("/dev/nvidia*") - if err != nil { - m.logger.Warningf("Error while locating device: %v", err) - } - - capDevices, err := locator.Locate("/dev/nvidia-caps/nvidia-*") - if err != nil { - m.logger.Warningf("Error while locating caps device: %v", err) - } - - if len(devices) == 0 && len(capDevices) == 0 { - m.logger.Infof("No NVIDIA devices found in %s", m.devRoot) - return nil, nil - } - - var deviceNodes []deviceNode - for _, d := range append(devices, capDevices...) { - if m.nodeIsBlocked(d) { - continue - } - var stat unix.Stat_t - err := unix.Stat(d, &stat) - if err != nil { - m.logger.Warningf("Could not stat device: %v", err) - continue - } - deviceNodes = append(deviceNodes, newDeviceNode(d, stat)) - } - - return deviceNodes, nil -} - -// nodeIsBlocked returns true if the specified device node should be ignored. 
-func (m existing) nodeIsBlocked(path string) bool { - blockedPrefixes := []string{"nvidia-fs", "nvidia-nvswitch", "nvidia-nvlink"} - nodeName := filepath.Base(path) - for _, prefix := range blockedPrefixes { - if strings.HasPrefix(nodeName, prefix) { - return true - } - } - return false -} diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing_linux.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing_linux.go deleted file mode 100644 index 4aab942af..000000000 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing_linux.go +++ /dev/null @@ -1,28 +0,0 @@ -/** -# Copyright (c) NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package devchar - -import "golang.org/x/sys/unix" - -func newDeviceNode(d string, stat unix.Stat_t) deviceNode { - deviceNode := deviceNode{ - path: d, - major: unix.Major(stat.Rdev), - minor: unix.Minor(stat.Rdev), - } - return deviceNode -} diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing_other.go b/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing_other.go deleted file mode 100644 index 9be96294b..000000000 --- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/existing_other.go +++ /dev/null @@ -1,30 +0,0 @@ -//go:build !linux - -/** -# Copyright (c) NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package devchar - -import "golang.org/x/sys/unix" - -func newDeviceNode(d string, stat unix.Stat_t) deviceNode { - deviceNode := deviceNode{ - path: d, - major: unix.Major(uint64(stat.Rdev)), - minor: unix.Minor(uint64(stat.Rdev)), - } - return deviceNode -} diff --git a/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go b/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go index 73b1fb4b1..bac5923d9 100644 --- a/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go +++ b/cmd/nvidia-ctk/system/create-device-nodes/create-device-nodes.go @@ -23,8 +23,8 @@ import ( "github.com/urfave/cli/v3" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvdevices" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvmodules" ) type command struct { diff --git a/internal/lookup/locator.go b/internal/lookup/locator.go index 9906327c4..b73d1d41f 100644 --- a/internal/lookup/locator.go +++ b/internal/lookup/locator.go @@ -27,3 +27,13 @@ type Locator interface { // ErrNotFound indicates that a specified pattern or file could not be found. 
var ErrNotFound = errors.New("not found")
+
+// always is a Locator that treats every requested path as present.
+type always string
+
+// Always is a ready-to-use locator that reports each queried path as found
+// without checking whether it exists.
+const Always = always("always")
+
+// Compile-time assertion that *always implements Locator.
+var _ Locator = (*always)(nil)
+
+// Locate returns the input s unchanged as the only candidate and never
+// returns an error.
+func (l always) Locate(s string) ([]string, error) {
+	return []string{s}, nil
+}
diff --git a/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go b/pkg/system/nvdevices/all.go
similarity index 95%
rename from cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go
rename to pkg/system/nvdevices/all.go
index cafb8f9c7..c79ffc153 100644
--- a/cmd/nvidia-ctk/system/create-dev-char-symlinks/all.go
+++ b/pkg/system/nvdevices/all.go
@@ -14,7 +14,7 @@
 # limitations under the License.
 **/
 
-package devchar
+package nvdevices
 
 import (
 	"fmt"
@@ -34,6 +34,20 @@ type allPossible struct {
 	migCaps nvcaps.MigCaps
 }
 
+// nodeLister is implemented by types that can enumerate device nodes.
+type nodeLister interface {
+	DeviceNodes() ([]deviceNode, error)
+}
+
+// deviceNode pairs a device node path with its major and minor numbers.
+type deviceNode struct {
+	path  string
+	major uint32
+	minor uint32
+}
+
+// devCharName formats the node as "<major>:<minor>".
+func (d deviceNode) devCharName() string {
+	return fmt.Sprintf("%d:%d", d.major, d.minor)
+}
+
 // newAllPossible returns a new allPossible device node lister.
 // This lister lists all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
func newAllPossible(logger logger.Interface, devRoot string) (nodeLister, error) { diff --git a/internal/system/nvdevices/devices.go b/pkg/system/nvdevices/devices.go similarity index 68% rename from internal/system/nvdevices/devices.go rename to pkg/system/nvdevices/devices.go index f667f6b76..3d250cb1e 100644 --- a/internal/system/nvdevices/devices.go +++ b/pkg/system/nvdevices/devices.go @@ -19,11 +19,13 @@ package nvdevices import ( "errors" "fmt" + "os" "path/filepath" "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" ) var errInvalidDeviceNode = errors.New("invalid device node") @@ -102,6 +104,67 @@ func (m *Interface) CreateNVIDIADevice(node string) error { return m.createDeviceNode(filepath.Join("dev", node), int(major), int(minor)) } +// CreateDevCharSymlinks creates symlinks at the specified path NVIDIA device nodes. +// If existingOnly is set to false, symlinks will be created for ALL possible devices. 
+func (m *Interface) CreateDevCharSymlinks(devCharPath string, existingOnly bool) error {
+	// Reject an empty or root devCharPath outright.
+	if devCharPath == "" || devCharPath == "/" {
+		return fmt.Errorf("invalid /dev/char path: %q", devCharPath)
+	}
+
+	// Always start from the full list of possible device nodes; the list is
+	// filtered per node below when existingOnly is set.
+	lister, err := newAllPossible(m.logger, m.devRoot)
+	if err != nil {
+		return fmt.Errorf("failed to create all possible device lister: %w", err)
+	}
+
+	deviceNodes, err := lister.DeviceNodes()
+	if err != nil {
+		return fmt.Errorf("failed to get device nodes: %w", err)
+	}
+
+	// Select how a candidate node is checked: a char-device lookup under
+	// devRoot when only existing nodes are wanted, or lookup.Always, which
+	// accepts every path, when links for all possible nodes are requested.
+	var deviceNodeLocator lookup.Locator
+	if existingOnly {
+		deviceNodeLocator = lookup.NewCharDeviceLocator(
+			lookup.WithLogger(m.logger),
+			lookup.WithRoot(m.devRoot),
+			lookup.WithCount(1),
+			lookup.WithOptional(true),
+		)
+	} else {
+		deviceNodeLocator = lookup.Always
+	}
+
+	var parentCreated bool
+	for _, deviceNode := range deviceNodes {
+		target := deviceNode.path
+		// TODO: This assumes that the majors for the kernel modules align with
+		// the majors for the actual device nodes.
+		linkPath := filepath.Join(devCharPath, deviceNode.devCharName())
+
+		candidates, err := deviceNodeLocator.Locate(target)
+		if err != nil || len(candidates) == 0 {
+			m.logger.Debugf("Ignoring non-existing device node %q", target)
+			// Skip the node: without this, existingOnly had no effect and a
+			// link was created for every possible device.
+			continue
+		}
+
+		m.logger.Infof("Creating link %s => %s", linkPath, target)
+		if m.dryRun {
+			continue
+		}
+
+		// Create the parent directory lazily so that nothing is written when
+		// there are no links to create.
+		if !parentCreated {
+			if err := os.MkdirAll(devCharPath, 0755); err != nil {
+				return fmt.Errorf("failed to create directory %s: %w", devCharPath, err)
+			}
+			parentCreated = true
+		}
+
+		// A failed link is logged and the remaining nodes are still processed.
+		if err := os.Symlink(target, linkPath); err != nil {
+			m.logger.Warningf("Could not create symlink: %v", err)
+		}
+	}
+
+	return nil
+}
+
 // createDeviceNode creates the specified device node with the require major and minor numbers.
 // If a devRoot is configured, this is prepended to the path.
func (m *Interface) createDeviceNode(path string, major int, minor int) error { diff --git a/internal/system/nvdevices/devices_test.go b/pkg/system/nvdevices/devices_test.go similarity index 100% rename from internal/system/nvdevices/devices_test.go rename to pkg/system/nvdevices/devices_test.go diff --git a/internal/system/nvdevices/mknod.go b/pkg/system/nvdevices/mknod.go similarity index 100% rename from internal/system/nvdevices/mknod.go rename to pkg/system/nvdevices/mknod.go diff --git a/internal/system/nvdevices/mknod_mock.go b/pkg/system/nvdevices/mknod_mock.go similarity index 100% rename from internal/system/nvdevices/mknod_mock.go rename to pkg/system/nvdevices/mknod_mock.go diff --git a/internal/system/nvdevices/options.go b/pkg/system/nvdevices/options.go similarity index 100% rename from internal/system/nvdevices/options.go rename to pkg/system/nvdevices/options.go diff --git a/internal/system/nvmodules/cmd.go b/pkg/system/nvmodules/cmd.go similarity index 100% rename from internal/system/nvmodules/cmd.go rename to pkg/system/nvmodules/cmd.go diff --git a/internal/system/nvmodules/cmd_mock.go b/pkg/system/nvmodules/cmd_mock.go similarity index 100% rename from internal/system/nvmodules/cmd_mock.go rename to pkg/system/nvmodules/cmd_mock.go diff --git a/internal/system/nvmodules/modules.go b/pkg/system/nvmodules/modules.go similarity index 100% rename from internal/system/nvmodules/modules.go rename to pkg/system/nvmodules/modules.go diff --git a/internal/system/nvmodules/modules_test.go b/pkg/system/nvmodules/modules_test.go similarity index 100% rename from internal/system/nvmodules/modules_test.go rename to pkg/system/nvmodules/modules_test.go diff --git a/internal/system/nvmodules/options.go b/pkg/system/nvmodules/options.go similarity index 100% rename from internal/system/nvmodules/options.go rename to pkg/system/nvmodules/options.go