|
| 1 | +package modelsindex |
| 2 | + |
| 3 | +import ( |
| 4 | + "testing" |
| 5 | + |
| 6 | + "github.com/arduino/go-paths-helper" |
| 7 | + "github.com/stretchr/testify/assert" |
| 8 | + "github.com/stretchr/testify/require" |
| 9 | +) |
| 10 | + |
| 11 | +func TestGenerateModelsIndexFromFile(t *testing.T) { |
| 12 | + testdataPath := paths.New("testdata") |
| 13 | + |
| 14 | + t.Run("Valid Model list", func(t *testing.T) { |
| 15 | + modelsIndex, err := GenerateModelsIndexFromFile(testdataPath) |
| 16 | + require.NoError(t, err) |
| 17 | + require.NotNil(t, modelsIndex) |
| 18 | + |
| 19 | + models := modelsIndex.GetModels() |
| 20 | + assert.Len(t, models, 3, "Expected 3 models to be parsed") |
| 21 | + |
| 22 | + // Test first model |
| 23 | + model1, found := modelsIndex.GetModelByID("face-detection") |
| 24 | + assert.Equal(t, "brick", model1.Runner) |
| 25 | + require.True(t, found, "face-detection should be found") |
| 26 | + assert.Equal(t, "face-detection:", model1.ID) |
| 27 | + assert.Equal(t, "Lightweight-Face-Detection", model1.Name) |
| 28 | + assert.Equal(t, "Face bounding box detection. This model is trained on the WIDER FACE dataset and can detect faces in images.", model1.ModuleDescription) |
| 29 | + assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model1.L) |
| 30 | + assert.Equal(t, []string{"arduino:object_detection", "arduino:video_object_detection"}, model1.Bricks) |
| 31 | + assert.Equal(t, "1.0.0", model1.Metadata["version"]) |
| 32 | + assert.Equal(t, "Test Author", model1.Metadata["author"]) |
| 33 | + assert.Equal(t, "1000", model1.ModelConfiguration["max_tokens"]) |
| 34 | + assert.Equal(t, "0.7", model1.ModelConfiguration["temperature"]) |
| 35 | + |
| 36 | + // // Test second model |
| 37 | + // model2, found := modelsIndex.GetModelByID("test_model_2") |
| 38 | + // // require.True(t, found, "test_model_2 should be found") |
| 39 | + // // assert.Equal(t, "test_model_2", model2.ID) |
| 40 | + // // assert.Equal(t, "Test Model 2", model2.Name) |
| 41 | + // // assert.Equal(t, "Another test AI model", model2.ModuleDescription) |
| 42 | + // // assert.Equal(t, "another_runner", model2.Runner) |
| 43 | + // // assert.Equal(t, []string{"brick2", "brick3"}, model2.Bricks) |
| 44 | + // // assert.Equal(t, "2.0.0", model2.Metadata["version"]) |
| 45 | + // // assert.Equal(t, "MIT", model2.Metadata["license"]) |
| 46 | + |
| 47 | + // // Test minimal model |
| 48 | + // model3, found := modelsIndex.GetModelByID("minimal_model") |
| 49 | + // require.True(t, found, "minimal_model should be found") |
| 50 | + // assert.Equal(t, "minimal_model", model3.ID) |
| 51 | + // assert.Equal(t, "Minimal Model", model3.Name) |
| 52 | + // assert.Equal(t, "Minimal model with no optional fields", model3.ModuleDescription) |
| 53 | + // assert.Equal(t, "minimal_runner", model3.Runner) |
| 54 | + // assert.Empty(t, model3.Bricks) |
| 55 | + // assert.Empty(t, model3.Metadata) |
| 56 | + // assert.Empty(t, model3.ModelConfiguration) |
| 57 | + }) |
| 58 | + |
| 59 | + // Test file not found error |
| 60 | + t.Run("FileNotFound", func(t *testing.T) { |
| 61 | + nonExistentPath := paths.New("nonexistent") |
| 62 | + modelsIndex, err := GenerateModelsIndexFromFile(nonExistentPath) |
| 63 | + assert.Error(t, err) |
| 64 | + assert.Nil(t, modelsIndex) |
| 65 | + }) |
| 66 | + |
| 67 | + // Test invalid YAML parsing |
| 68 | + t.Run("InvalidYAML", func(t *testing.T) { |
| 69 | + // Create a temporary invalid YAML file |
| 70 | + invalidPath := testdataPath.Join("invalid-models.yaml") |
| 71 | + |
| 72 | + // We expect this to either fail parsing or handle gracefully |
| 73 | + // Since the current implementation may be lenient with missing fields |
| 74 | + modelsIndex, err := GenerateModelsIndexFromFile(testdataPath.Parent().Join("testdata-invalid")) |
| 75 | + if err != nil { |
| 76 | + // If it fails, that's expected for invalid files |
| 77 | + assert.Error(t, err) |
| 78 | + assert.Nil(t, modelsIndex) |
| 79 | + } |
| 80 | + // Note: Some invalid YAML might still parse successfully depending on the YAML library's behavior |
| 81 | + _ = invalidPath // Avoid unused variable warning |
| 82 | + }) |
| 83 | + |
| 84 | + // Test brick filtering functionality |
| 85 | + t.Run("BrickFiltering", func(t *testing.T) { |
| 86 | + modelsIndex, err := GenerateModelsIndexFromFile(testdataPath) |
| 87 | + require.NoError(t, err) |
| 88 | + |
| 89 | + // Test GetModelsByBrick |
| 90 | + brick1Models := modelsIndex.GetModelsByBrick("brick1") |
| 91 | + assert.Len(t, brick1Models, 1) |
| 92 | + assert.Equal(t, "test_model_1", brick1Models[0].ID) |
| 93 | + |
| 94 | + brick2Models := modelsIndex.GetModelsByBrick("brick2") |
| 95 | + assert.Len(t, brick2Models, 2) |
| 96 | + modelIDs := []string{brick2Models[0].ID, brick2Models[1].ID} |
| 97 | + assert.Contains(t, modelIDs, "test_model_1") |
| 98 | + assert.Contains(t, modelIDs, "test_model_2") |
| 99 | + |
| 100 | + // Test GetModelsByBricks |
| 101 | + multiModels := modelsIndex.GetModelsByBricks([]string{"brick1", "brick3"}) |
| 102 | + assert.Len(t, multiModels, 2) |
| 103 | + multiModelIDs := []string{multiModels[0].ID, multiModels[1].ID} |
| 104 | + assert.Contains(t, multiModelIDs, "test_model_1") |
| 105 | + assert.Contains(t, multiModelIDs, "test_model_2") |
| 106 | + |
| 107 | + // Test non-existent brick |
| 108 | + nonExistentModels := modelsIndex.GetModelsByBrick("nonexistent_brick") |
| 109 | + assert.Nil(t, nonExistentModels) |
| 110 | + }) |
| 111 | +} |
0 commit comments