Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/nvidia-ctk-installer/toolkit/toolkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit/installer"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
transformroot "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform/root"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvdevices"
)

const (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@ package devchar
import (
"context"
"fmt"
"os"
"path/filepath"

"github.com/urfave/cli/v3"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvmodules"
)

const (
Expand Down Expand Up @@ -155,14 +153,15 @@ func (m command) run(cfg *config) error {

type linkCreator struct {
logger logger.Interface
lister nodeLister
driverRoot string
devRoot string
devCharPath string
dryRun bool
createAll bool
createDeviceNodes bool
loadKernelModules bool

devicesLib *nvdevices.Interface
}

// Creator is an interface for creating symlinks to /dev/nv* devices in /dev/char.
Expand All @@ -174,6 +173,8 @@ type Creator interface {
type Option func(*linkCreator)

// NewSymlinkCreator creates a new linkCreator.
//
// Deprecated: Use the `nvdevices` package instead.
func NewSymlinkCreator(opts ...Option) (Creator, error) {
c := linkCreator{}
for _, opt := range opts {
Expand All @@ -192,52 +193,34 @@ func NewSymlinkCreator(opts ...Option) (Creator, error) {
c.devCharPath = defaultDevCharPath
}

if err := c.setup(); err != nil {
return nil, err
}

if c.createAll {
lister, err := newAllPossible(c.logger, c.devRoot)
if err != nil {
return nil, fmt.Errorf("failed to create all possible device lister: %v", err)
}
c.lister = lister
} else {
c.lister = existing{c.logger, c.devRoot}
}
return c, nil
}

func (m linkCreator) setup() error {
if !m.loadKernelModules && !m.createDeviceNodes {
return nil
}

if m.loadKernelModules {
if c.loadKernelModules {
modules := nvmodules.New(
nvmodules.WithLogger(m.logger),
nvmodules.WithDryRun(m.dryRun),
nvmodules.WithRoot(m.driverRoot),
nvmodules.WithLogger(c.logger),
nvmodules.WithDryRun(c.dryRun),
nvmodules.WithRoot(c.driverRoot),
)
if err := modules.LoadAll(); err != nil {
return fmt.Errorf("failed to load NVIDIA kernel modules: %v", err)
return nil, fmt.Errorf("failed to load NVIDIA kernel modules: %v", err)
}
}

if m.createDeviceNodes {
devices, err := nvdevices.New(
nvdevices.WithLogger(m.logger),
nvdevices.WithDryRun(m.dryRun),
nvdevices.WithDevRoot(m.devRoot),
)
if err != nil {
return err
}
devices, err := nvdevices.New(
nvdevices.WithLogger(c.logger),
nvdevices.WithDryRun(c.dryRun),
nvdevices.WithDevRoot(c.driverRoot),
)
if err != nil {
return nil, err
}
c.devicesLib = devices

if c.createDeviceNodes {
if err := devices.CreateNVIDIAControlDevices(); err != nil {
return fmt.Errorf("failed to create NVIDIA device nodes: %v", err)
return nil, fmt.Errorf("failed to create NVIDIA device nodes: %v", err)
}
}
return nil

return c, nil
}

// WithDriverRoot sets the driver root path.
Expand Down Expand Up @@ -299,42 +282,5 @@ func WithCreateDeviceNodes(createDeviceNodes bool) Option {

// CreateLinks creates symlinks for all NVIDIA device nodes found in the driver root.
func (m linkCreator) CreateLinks() error {
deviceNodes, err := m.lister.DeviceNodes()
if err != nil {
return fmt.Errorf("failed to get device nodes: %v", err)
}

if len(deviceNodes) != 0 && !m.dryRun {
err := os.MkdirAll(m.devCharPath, 0755)
if err != nil {
return fmt.Errorf("failed to create directory %s: %v", m.devCharPath, err)
}
}

for _, deviceNode := range deviceNodes {
target := deviceNode.path
linkPath := filepath.Join(m.devCharPath, deviceNode.devCharName())

m.logger.Infof("Creating link %s => %s", linkPath, target)
if m.dryRun {
continue
}

err = os.Symlink(target, linkPath)
if err != nil {
m.logger.Warningf("Could not create symlink: %v", err)
}
}

return nil
}

type deviceNode struct {
path string
major uint32
minor uint32
}

func (d deviceNode) devCharName() string {
return fmt.Sprintf("%d:%d", d.major, d.minor)
return m.devicesLib.CreateDevCharSymlinks(m.devCharPath, !m.createAll)
}
89 changes: 0 additions & 89 deletions cmd/nvidia-ctk/system/create-dev-char-symlinks/existing.go

This file was deleted.

28 changes: 0 additions & 28 deletions cmd/nvidia-ctk/system/create-dev-char-symlinks/existing_linux.go

This file was deleted.

30 changes: 0 additions & 30 deletions cmd/nvidia-ctk/system/create-dev-char-symlinks/existing_other.go

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
"github.com/urfave/cli/v3"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvmodules"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/system/nvmodules"
)

type command struct {
Expand Down
10 changes: 10 additions & 0 deletions internal/lookup/locator.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,13 @@ type Locator interface {

// ErrNotFound indicates that a specified pattern or file could not be found.
var ErrNotFound = errors.New("not found")

type always string

const Always = always("always")

var _ Locator = (*always)(nil)

func (l always) Locate(s string) ([]string, error) {
return []string{s}, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
**/

package devchar
package nvdevices

import (
"fmt"
Expand All @@ -34,6 +34,20 @@ type allPossible struct {
migCaps nvcaps.MigCaps
}

type nodeLister interface {
DeviceNodes() ([]deviceNode, error)
}

type deviceNode struct {
path string
major uint32
minor uint32
}

func (d deviceNode) devCharName() string {
return fmt.Sprintf("%d:%d", d.major, d.minor)
}

// newAllPossible returns a new allPossible device node lister.
// This lister lists all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
func newAllPossible(logger logger.Interface, devRoot string) (nodeLister, error) {
Expand Down
Loading