Skip to content

Commit 255c1ed

Browse files
committed
feat(models): add model_labels field and update test data for models
1 parent 0509013 commit 255c1ed

File tree

4 files changed

+232
-0
lines changed

4 files changed

+232
-0
lines changed

internal/orchestrator/modelsindex/models_index.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ type AIModel struct {
4848
ModuleDescription string `yaml:"description"`
4949
Runner string `yaml:"runner"`
5050
Bricks []string `yaml:"bricks,omitempty"`
51+
ModelLabels []string `yaml:"model_labels,omitempty"`
5152
Metadata map[string]string `yaml:"metadata,omitempty"`
5253
ModelConfiguration map[string]string `yaml:"model_configuration,omitempty"`
5354
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package modelsindex
2+
3+
import (
4+
"testing"
5+
6+
"github.com/arduino/go-paths-helper"
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestGenerateModelsIndexFromFile(t *testing.T) {
12+
testdataPath := paths.New("testdata")
13+
14+
t.Run("Valid Model list", func(t *testing.T) {
15+
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath)
16+
require.NoError(t, err)
17+
require.NotNil(t, modelsIndex)
18+
19+
models := modelsIndex.GetModels()
20+
assert.Len(t, models, 3, "Expected 3 models to be parsed")
21+
22+
// Test first model
23+
model1, found := modelsIndex.GetModelByID("face-detection")
24+
assert.Equal(t, "brick", model1.Runner)
25+
require.True(t, found, "face-detection should be found")
26+
assert.Equal(t, "face-detection:", model1.ID)
27+
assert.Equal(t, "Lightweight-Face-Detection", model1.Name)
28+
assert.Equal(t, "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images.", model1.ModuleDescription)
29+
assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model1.L)
30+
assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model1.Bricks)
31+
assert.Equal(t, "1.0.0", model1.Metadata["version"])
32+
assert.Equal(t, "Test Author", model1.Metadata["author"])
33+
assert.Equal(t, "1000", model1.ModelConfiguration["max_tokens"])
34+
assert.Equal(t, "0.7", model1.ModelConfiguration["temperature"])
35+
36+
// // Test second model
37+
// model2, found := modelsIndex.GetModelByID("test_model_2")
38+
// // require.True(t, found, "test_model_2 should be found")
39+
// // assert.Equal(t, "test_model_2", model2.ID)
40+
// // assert.Equal(t, "Test Model 2", model2.Name)
41+
// // assert.Equal(t, "Another test AI model", model2.ModuleDescription)
42+
// // assert.Equal(t, "another_runner", model2.Runner)
43+
// // assert.Equal(t, []string{"brick2", "brick3"}, model2.Bricks)
44+
// // assert.Equal(t, "2.0.0", model2.Metadata["version"])
45+
// // assert.Equal(t, "MIT", model2.Metadata["license"])
46+
47+
// // Test minimal model
48+
// model3, found := modelsIndex.GetModelByID("minimal_model")
49+
// require.True(t, found, "minimal_model should be found")
50+
// assert.Equal(t, "minimal_model", model3.ID)
51+
// assert.Equal(t, "Minimal Model", model3.Name)
52+
// assert.Equal(t, "Minimal model with no optional fields", model3.ModuleDescription)
53+
// assert.Equal(t, "minimal_runner", model3.Runner)
54+
// assert.Empty(t, model3.Bricks)
55+
// assert.Empty(t, model3.Metadata)
56+
// assert.Empty(t, model3.ModelConfiguration)
57+
})
58+
59+
// Test file not found error
60+
t.Run("FileNotFound", func(t *testing.T) {
61+
nonExistentPath := paths.New("nonexistent")
62+
modelsIndex, err := GenerateModelsIndexFromFile(nonExistentPath)
63+
assert.Error(t, err)
64+
assert.Nil(t, modelsIndex)
65+
})
66+
67+
// Test invalid YAML parsing
68+
t.Run("InvalidYAML", func(t *testing.T) {
69+
// Create a temporary invalid YAML file
70+
invalidPath := testdataPath.Join("invalid-models.yaml")
71+
72+
// We expect this to either fail parsing or handle gracefully
73+
// Since the current implementation may be lenient with missing fields
74+
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath.Parent().Join("testdata-invalid"))
75+
if err != nil {
76+
// If it fails, that's expected for invalid files
77+
assert.Error(t, err)
78+
assert.Nil(t, modelsIndex)
79+
}
80+
// Note: Some invalid YAML might still parse successfully depending on the YAML library's behavior
81+
_ = invalidPath // Avoid unused variable warning
82+
})
83+
84+
// Test brick filtering functionality
85+
t.Run("BrickFiltering", func(t *testing.T) {
86+
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath)
87+
require.NoError(t, err)
88+
89+
// Test GetModelsByBrick
90+
brick1Models := modelsIndex.GetModelsByBrick("brick1")
91+
assert.Len(t, brick1Models, 1)
92+
assert.Equal(t, "test_model_1", brick1Models[0].ID)
93+
94+
brick2Models := modelsIndex.GetModelsByBrick("brick2")
95+
assert.Len(t, brick2Models, 2)
96+
modelIDs := []string{brick2Models[0].ID, brick2Models[1].ID}
97+
assert.Contains(t, modelIDs, "test_model_1")
98+
assert.Contains(t, modelIDs, "test_model_2")
99+
100+
// Test GetModelsByBricks
101+
multiModels := modelsIndex.GetModelsByBricks([]string{"brick1", "brick3"})
102+
assert.Len(t, multiModels, 2)
103+
multiModelIDs := []string{multiModels[0].ID, multiModels[1].ID}
104+
assert.Contains(t, multiModelIDs, "test_model_1")
105+
assert.Contains(t, multiModelIDs, "test_model_2")
106+
107+
// Test non-existent brick
108+
nonExistentModels := modelsIndex.GetModelsByBrick("nonexistent_brick")
109+
assert.Nil(t, nonExistentModels)
110+
})
111+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
models:
2+
- invalid_model:
3+
name: "Invalid Model"
4+
description: "Missing required fields"
5+
# Missing runner field
6+
invalid_field: "this should cause parsing issues"
7+
- another_invalid:
8+
name: 123 # Invalid type for name field
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
models:
2+
- face-detection:
3+
runner: brick
4+
name : "Lightweight-Face-Detection"
5+
description: "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images."
6+
model_configuration:
7+
"EI_OBJ_DETECTION_MODEL": "/models/ootb/ei/lw-face-det.eim"
8+
model_labels:
9+
- face
10+
bricks:
11+
- arduino:object_detection
12+
- arduino:video_object_detection
13+
metadata:
14+
source: "qualcomm-ai-hub"
15+
ei-gpu-mode: false
16+
source-model-id: "face-det-lite"
17+
source-model-url: "https://aihub.qualcomm.com/models/face_det_lite"
18+
- yolox-object-detection:
19+
runner: brick
20+
name : "General purpose object detection - YoloX"
21+
description: "General purpose object detection model based on YoloX Nano. This model is trained on the COCO dataset and can detect 80 different object classes."
22+
model_configuration:
23+
"EI_OBJ_DETECTION_MODEL": "/models/ootb/ei/yolo-x-nano.eim"
24+
model_labels:
25+
- airplane
26+
- apple
27+
- backpack
28+
- banana
29+
- baseball bat
30+
- baseball glove
31+
- bear
32+
- bed
33+
- bench
34+
- bicycle
35+
- bird
36+
- boat
37+
- book
38+
- bottle
39+
- bowl
40+
- broccoli
41+
- bus
42+
- cake
43+
- car
44+
- carrot
45+
- cat
46+
- cell phone
47+
- chair
48+
- clock
49+
- couch
50+
- cow
51+
- cup
52+
- dining table
53+
- dog
54+
- donut
55+
- elephant
56+
- fire hydrant
57+
- fork
58+
- frisbee
59+
- giraffe
60+
- hair drier
61+
- handbag
62+
- hot dog
63+
- horse
64+
- keyboard
65+
- kite
66+
- knife
67+
- laptop
68+
- microwave
69+
- motorcycle
70+
- mouse
71+
- orange
72+
- oven
73+
- parking meter
74+
- person
75+
- pizza
76+
- potted plant
77+
- refrigerator
78+
- remote
79+
- sandwich
80+
- scissors
81+
- sheep
82+
- sink
83+
- skateboard
84+
- skis
85+
- snowboard
86+
- spoon
87+
- sports ball
88+
- stop sign
89+
- suitcase
90+
- surfboard
91+
- teddy bear
92+
- tennis racket
93+
- tie
94+
- toaster
95+
- toilet
96+
- toothbrush
97+
- traffic light
98+
- train
99+
- truck
100+
- tv
101+
- umbrella
102+
- vase
103+
- wine glass
104+
- zebra
105+
metadata:
106+
source: "edgeimpulse"
107+
ei-project-id: 717280
108+
source-model-id: "YOLOX-Nano"
109+
source-model-url: "https://github.com/Megvii-BaseDetection/YOLOX"
110+
bricks:
111+
- arduino:object_detection
112+
- arduino:video_object_detection

0 commit comments

Comments
 (0)