diff --git a/internal/orchestrator/modelsindex/models_index.go b/internal/orchestrator/modelsindex/models_index.go index a966a678..e18797f1 100644 --- a/internal/orchestrator/modelsindex/models_index.go +++ b/internal/orchestrator/modelsindex/models_index.go @@ -48,6 +48,7 @@ type AIModel struct { ModuleDescription string `yaml:"description"` Runner string `yaml:"runner"` Bricks []string `yaml:"bricks,omitempty"` + ModelLabels []string `yaml:"model_labels,omitempty"` Metadata map[string]string `yaml:"metadata,omitempty"` ModelConfiguration map[string]string `yaml:"model_configuration,omitempty"` } diff --git a/internal/orchestrator/modelsindex/modelsindex_test.go b/internal/orchestrator/modelsindex/modelsindex_test.go new file mode 100644 index 00000000..53ffb585 --- /dev/null +++ b/internal/orchestrator/modelsindex/modelsindex_test.go @@ -0,0 +1,72 @@ +package modelsindex + +import ( + "testing" + + "github.com/arduino/go-paths-helper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestModelsIndex(t *testing.T) { + modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata")) + require.NoError(t, err) + require.NotNil(t, modelsIndex) + + t.Run("it parses a valid model-list.yaml", func(t *testing.T) { + models := modelsIndex.GetModels() + assert.Len(t, models, 2, "Expected 2 models to be parsed") + }) + + t.Run("it gets a model by ID", func(t *testing.T) { + model, found := modelsIndex.GetModelByID("not-existing-model") + assert.False(t, found) + assert.Nil(t, model) + + model, found = modelsIndex.GetModelByID("face-detection") + assert.Equal(t, "brick", model.Runner) + require.True(t, found, "face-detection should be found") + assert.Equal(t, "face-detection", model.ID) + assert.Equal(t, "Lightweight-Face-Detection", model.Name) + assert.Equal(t, "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images.", model.ModuleDescription) + assert.Equal(t, []string{"face"}, model.ModelLabels) + assert.Equal(t, "/models/ootb/ei/lw-face-det.eim", model.ModelConfiguration["EI_OBJ_DETECTION_MODEL"]) + assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model.Bricks) + assert.Equal(t, "qualcomm-ai-hub", model.Metadata["source"]) + assert.Equal(t, "false", model.Metadata["ei-gpu-mode"]) + assert.Equal(t, "face-det-lite", model.Metadata["source-model-id"]) + assert.Equal(t, "https://aihub.qualcomm.com/models/face_det_lite", model.Metadata["source-model-url"]) + }) + + t.Run("it fails if model-list.yaml does not exist", func(t *testing.T) { + nonExistentPath := paths.New("nonexistentdir") + modelsIndex, err := GenerateModelsIndexFromFile(nonExistentPath) + assert.Error(t, err) + assert.Nil(t, modelsIndex) + }) + + t.Run("it gets models by a brick", func(t *testing.T) { + model := modelsIndex.GetModelsByBrick("not-existing-brick") + assert.Nil(t, model) + + model = modelsIndex.GetModelsByBrick("arduino:object_detection") + assert.Len(t, model, 1) + assert.Equal(t, "face-detection", model[0].ID) + }) + + t.Run("it gets models by bricks", func(t *testing.T) { + models := modelsIndex.GetModelsByBricks([]string{"arduino:non_existing"}) + assert.Len(t, models, 0) + assert.Nil(t, models) + + models = modelsIndex.GetModelsByBricks([]string{"arduino:video_object_detection"}) + assert.Len(t, models, 2) + assert.Equal(t, "face-detection", models[0].ID) + assert.Equal(t, "yolox-object-detection", models[1].ID) + + models = modelsIndex.GetModelsByBricks([]string{"arduino:object_detection", "arduino:video_object_detection"}) + assert.Len(t, models, 2) + assert.Equal(t, "face-detection", models[0].ID) + assert.Equal(t, "yolox-object-detection", models[1].ID) + }) +} diff --git a/internal/orchestrator/modelsindex/testdata/models-list.yaml b/internal/orchestrator/modelsindex/testdata/models-list.yaml new file mode 100644 index 00000000..7d0aefb5 --- /dev/null +++ b/internal/orchestrator/modelsindex/testdata/models-list.yaml @@ -0,0 +1,111 @@ +models: + - face-detection: + runner: brick + name : "Lightweight-Face-Detection" + description: "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images." + model_configuration: + "EI_OBJ_DETECTION_MODEL": "/models/ootb/ei/lw-face-det.eim" + model_labels: + - face + bricks: + - arduino:object_detection + - arduino:video_object_detection + metadata: + source: "qualcomm-ai-hub" + ei-gpu-mode: false + source-model-id: "face-det-lite" + source-model-url: "https://aihub.qualcomm.com/models/face_det_lite" + - yolox-object-detection: + runner: brick + name : "General purpose object detection - YoloX" + description: "General purpose object detection model based on YoloX Nano. This model is trained on the COCO dataset and can detect 80 different object classes." + model_configuration: + "EI_OBJ_DETECTION_MODEL": "/models/ootb/ei/yolo-x-nano.eim" + model_labels: + - airplane + - apple + - backpack + - banana + - baseball bat + - baseball glove + - bear + - bed + - bench + - bicycle + - bird + - boat + - book + - bottle + - bowl + - broccoli + - bus + - cake + - car + - carrot + - cat + - cell phone + - chair + - clock + - couch + - cow + - cup + - dining table + - dog + - donut + - elephant + - fire hydrant + - fork + - frisbee + - giraffe + - hair drier + - handbag + - hot dog + - horse + - keyboard + - kite + - knife + - laptop + - microwave + - motorcycle + - mouse + - orange + - oven + - parking meter + - person + - pizza + - potted plant + - refrigerator + - remote + - sandwich + - scissors + - sheep + - sink + - skateboard + - skis + - snowboard + - spoon + - sports ball + - stop sign + - suitcase + - surfboard + - teddy bear + - tennis racket + - tie + - toaster + - toilet + - toothbrush + - traffic light + - train + - truck + - tv + - umbrella + - vase + - wine glass + - zebra + metadata: + source: "edgeimpulse" + ei-project-id: 717280 + source-model-id: "YOLOX-Nano" + source-model-url: "https://github.com/Megvii-BaseDetection/YOLOX" + bricks: + - arduino:video_object_detection