Skip to content

Commit e3eec7b

Browse files
dido18lucarin91
authored andcommitted
refactor(tests): update model index tests for clarity and accuracy
1 parent 454ecb3 commit e3eec7b

File tree

3 files changed

+40
-84
lines changed

3 files changed

+40
-84
lines changed

internal/orchestrator/modelsindex/modelsindex_test.go

Lines changed: 40 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -9,102 +9,67 @@ import (
99
)
1010

1111
func TestGenerateModelsIndexFromFile(t *testing.T) {
12-
testdataPath := paths.New("testdata")
13-
14-
t.Run("Valid Model list", func(t *testing.T) {
15-
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath)
12+
t.Run("it parses a valid model-list.yaml", func(t *testing.T) {
13+
modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata"))
1614
require.NoError(t, err)
1715
require.NotNil(t, modelsIndex)
1816

1917
models := modelsIndex.GetModels()
20-
assert.Len(t, models, 3, "Expected 3 models to be parsed")
21-
22-
// Test first model
23-
model1, found := modelsIndex.GetModelByID("face-detection")
24-
assert.Equal(t, "brick", model1.Runner)
25-
require.True(t, found, "face-detection should be found")
26-
assert.Equal(t, "face-detection:", model1.ID)
27-
assert.Equal(t, "Lightweight-Face-Detection", model1.Name)
28-
assert.Equal(t, "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images.", model1.ModuleDescription)
29-
assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model1.L)
30-
assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model1.Bricks)
31-
assert.Equal(t, "1.0.0", model1.Metadata["version"])
32-
assert.Equal(t, "Test Author", model1.Metadata["author"])
33-
assert.Equal(t, "1000", model1.ModelConfiguration["max_tokens"])
34-
assert.Equal(t, "0.7", model1.ModelConfiguration["temperature"])
18+
assert.Len(t, models, 2, "Expected 2 models to be parsed")
19+
})
3520

36-
// // Test second model
37-
// model2, found := modelsIndex.GetModelByID("test_model_2")
38-
// // require.True(t, found, "test_model_2 should be found")
39-
// // assert.Equal(t, "test_model_2", model2.ID)
40-
// // assert.Equal(t, "Test Model 2", model2.Name)
41-
// // assert.Equal(t, "Another test AI model", model2.ModuleDescription)
42-
// // assert.Equal(t, "another_runner", model2.Runner)
43-
// // assert.Equal(t, []string{"brick2", "brick3"}, model2.Bricks)
44-
// // assert.Equal(t, "2.0.0", model2.Metadata["version"])
45-
// // assert.Equal(t, "MIT", model2.Metadata["license"])
21+
t.Run("it gets a model by ID", func(t *testing.T) {
22+
modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata"))
23+
require.NoError(t, err)
4624

47-
// // Test minimal model
48-
// model3, found := modelsIndex.GetModelByID("minimal_model")
49-
// require.True(t, found, "minimal_model should be found")
50-
// assert.Equal(t, "minimal_model", model3.ID)
51-
// assert.Equal(t, "Minimal Model", model3.Name)
52-
// assert.Equal(t, "Minimal model with no optional fields", model3.ModuleDescription)
53-
// assert.Equal(t, "minimal_runner", model3.Runner)
54-
// assert.Empty(t, model3.Bricks)
55-
// assert.Empty(t, model3.Metadata)
56-
// assert.Empty(t, model3.ModelConfiguration)
25+
model, found := modelsIndex.GetModelByID("face-detection")
26+
assert.Equal(t, "brick", model.Runner)
27+
require.True(t, found, "face-detection should be found")
28+
assert.Equal(t, "face-detection", model.ID)
29+
assert.Equal(t, "Lightweight-Face-Detection", model.Name)
30+
assert.Equal(t, "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images.", model.ModuleDescription)
31+
assert.Equal(t, []string{"face"}, model.ModelLabels)
32+
assert.Equal(t, "/models/ootb/ei/lw-face-det.eim", model.ModelConfiguration["EI_OBJ_DETECTION_MODEL"])
33+
assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model.Bricks)
34+
assert.Equal(t, "qualcomm-ai-hub", model.Metadata["source"])
35+
assert.Equal(t, "false", model.Metadata["ei-gpu-mode"])
36+
assert.Equal(t, "face-det-lite", model.Metadata["source-model-id"])
37+
assert.Equal(t, "https://aihub.qualcomm.com/models/face_det_lite", model.Metadata["source-model-url"])
5738
})
5839

59-
// Test file not found error
60-
t.Run("FileNotFound", func(t *testing.T) {
61-
nonExistentPath := paths.New("nonexistent")
40+
t.Run("it fails if model-list.yaml does not exist", func(t *testing.T) {
41+
nonExistentPath := paths.New("nonexistent.yaml")
6242
modelsIndex, err := GenerateModelsIndexFromFile(nonExistentPath)
6343
assert.Error(t, err)
6444
assert.Nil(t, modelsIndex)
6545
})
6646

67-
// Test invalid YAML parsing
68-
t.Run("InvalidYAML", func(t *testing.T) {
69-
// Create a temporary invalid YAML file
70-
invalidPath := testdataPath.Join("invalid-models.yaml")
47+
t.Run("it filters models by a single brick", func(t *testing.T) {
48+
modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata"))
49+
require.NoError(t, err)
50+
51+
brick1Models := modelsIndex.GetModelsByBrick("arduino:object_detection")
52+
assert.Len(t, brick1Models, 1)
53+
assert.Equal(t, "face-detection", brick1Models[0].ID)
7154

72-
// We expect this to either fail parsing or handle gracefully
73-
// Since the current implementation may be lenient with missing fields
74-
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath.Parent().Join("testdata-invalid"))
75-
if err != nil {
76-
// If it fails, that's expected for invalid files
77-
assert.Error(t, err)
78-
assert.Nil(t, modelsIndex)
79-
}
80-
// Note: Some invalid YAML might still parse successfully depending on the YAML library's behavior
81-
_ = invalidPath // Avoid unused variable warning
55+
brick1Models = modelsIndex.GetModelsByBrick("not-existing-brick")
56+
assert.Nil(t, brick1Models)
8257
})
8358

84-
// Test brick filtering functionality
85-
t.Run("BrickFiltering", func(t *testing.T) {
86-
modelsIndex, err := GenerateModelsIndexFromFile(testdataPath)
59+
t.Run("it filters models by multiple bricks", func(t *testing.T) {
60+
modelsIndex, err := GenerateModelsIndexFromFile(paths.New("testdata"))
8761
require.NoError(t, err)
8862

89-
// Test GetModelsByBrick
90-
brick1Models := modelsIndex.GetModelsByBrick("brick1")
91-
assert.Len(t, brick1Models, 1)
92-
assert.Equal(t, "test_model_1", brick1Models[0].ID)
93-
94-
brick2Models := modelsIndex.GetModelsByBrick("brick2")
63+
brick2Models := modelsIndex.GetModelsByBrick("arduino:video_object_detection")
9564
assert.Len(t, brick2Models, 2)
96-
modelIDs := []string{brick2Models[0].ID, brick2Models[1].ID}
97-
assert.Contains(t, modelIDs, "test_model_1")
98-
assert.Contains(t, modelIDs, "test_model_2")
65+
assert.Equal(t, "face-detection", brick2Models[0].ID)
66+
assert.Equal(t, "yolox-object-detection", brick2Models[1].ID)
9967

100-
// Test GetModelsByBricks
101-
multiModels := modelsIndex.GetModelsByBricks([]string{"brick1", "brick3"})
102-
assert.Len(t, multiModels, 2)
103-
multiModelIDs := []string{multiModels[0].ID, multiModels[1].ID}
104-
assert.Contains(t, multiModelIDs, "test_model_1")
105-
assert.Contains(t, multiModelIDs, "test_model_2")
68+
bricks2Models := modelsIndex.GetModelsByBricks([]string{"arduino:object_detection", "arduino:video_object_detection"})
69+
assert.Len(t, bricks2Models, 2)
70+
assert.Equal(t, "face-detection", bricks2Models[0].ID)
71+
assert.Equal(t, "yolox-object-detection", bricks2Models[1].ID)
10672

107-
// Test non-existent brick
10873
nonExistentModels := modelsIndex.GetModelsByBrick("nonexistent_brick")
10974
assert.Nil(t, nonExistentModels)
11075
})

internal/orchestrator/modelsindex/testdata/invalid-models.yaml

Lines changed: 0 additions & 8 deletions
This file was deleted.

internal/orchestrator/modelsindex/testdata/models-list.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,4 @@ models:
108108
source-model-id: "YOLOX-Nano"
109109
source-model-url: "https://github.com/Megvii-BaseDetection/YOLOX"
110110
bricks:
111-
- arduino:object_detection
112111
- arduino:video_object_detection

0 commit comments

Comments
 (0)