diff --git a/app.yaml b/app.yaml index e5fe6e9..fe4f111 100644 --- a/app.yaml +++ b/app.yaml @@ -5,4 +5,5 @@ ports: bricks: - arduino:web_ui - arduino:object_detection + - arduino:image_classification icon: 🐱 diff --git a/python/main.py b/python/main.py index 59513c8..f032f58 100644 --- a/python/main.py +++ b/python/main.py @@ -1,10 +1,12 @@ from arduino.app_utils import App, Bridge from arduino.app_bricks.web_ui import WebUI from arduino.app_bricks.object_detection import ObjectDetection +from arduino.app_bricks.image_classification import ImageClassification import time import base64 object_detection = ObjectDetection() +image_classification = ImageClassification() def on_matrix_draw(_, data): @@ -46,6 +48,28 @@ def on_set_led_rgb(_, data): Bridge.call("set_led_rgb", led, r_digital, g_digital, b_digital) +def on_classify_image(client_id, data): + """Callback function to handle image classification requests.""" + image_data = data.get("image") + if not image_data: + ui.send_message("classification_error", {"error": "No image data"}) + return + start_time = time.time() * 1000 + # TODO: define confidence + results = image_classification.classify(base64.b64decode(image_data)) + diff = time.time() * 1000 - start_time + + if results is None: + ui.send_message("classification_error", {"error": "No results returned"}) + return + + response = { + "classification": results.get("classification", []), + "processing_time": f"{diff:.2f} ms", + } + ui.send_message("classification_result", response) + + def on_detect_objects(client_id, data): """Callback function to handle object detection requests.""" try: @@ -80,6 +104,7 @@ def on_detect_objects(client_id, data): ui.on_message("matrix_draw", on_matrix_draw) ui.on_message("set_led_rgb", on_set_led_rgb) ui.on_message("detect_objects", on_detect_objects) +ui.on_message("classify_image", on_classify_image) def on_modulino_button_pressed(btn): diff --git a/scratch-arduino-extensions/packages/scratch-vm/src/extensions/arduino_image_classification/index.js b/scratch-arduino-extensions/packages/scratch-vm/src/extensions/arduino_image_classification/index.js new file mode 100644 index 0000000..6c60ddf --- /dev/null +++ b/scratch-arduino-extensions/packages/scratch-vm/src/extensions/arduino_image_classification/index.js @@ -0,0 +1,82 @@ +const BlockType = require("../../../../../../scratch-editor/packages/scratch-vm/src/extension-support/block-type"); +const ArgumentType = require( + "../../../../../../scratch-editor/packages/scratch-vm/src/extension-support/argument-type", +); +const Video = require("../../../../../../scratch-editor/packages/scratch-vm/src/io/video"); +const ArduinoUnoQ = require("../ArduinoUnoQ"); + +// TODO add icons +const iconURI = ""; +const menuIconURI = ""; + + +class ArduinoImageClassification { + constructor(runtime) { + this.runtime = runtime; + + this.unoq = new ArduinoUnoQ(); + this.unoq.connect(); + + this.runtime.on("PROJECT_LOADED", () => { + if (!this.runtime.renderer) { + console.log("Renderer is NOT available in runtime."); + return; + } + }); + + this.unoq.on("classification_result", (data) => { + if (!data || !data.classification) { + console.log("No classification classification received."); + return; + } + if (data.classification.length === 0) { + console.log("No objects classified."); + return; + } + // {'classification': [{'class_name': 'neutral', 'confidence': '45.82'}, {'class_name': 'stop', 'confidence': '54.16'}]} + console.log(data.classification); + + }); + } +} + +ArduinoImageClassification.prototype.getInfo = function() { + return { + id: "ArduinoImageClassification", + name: "Arduino Image Classification", + menuIconURI: menuIconURI, + blockIconURI: iconURI, + blocks: [ + { + opcode: "classifyImage", + blockType: BlockType.COMMAND, + text: "classify image", + func: "classifyImage", + arguments: {}, + }, + ], + }; +}; + +ArduinoImageClassification.prototype.classifyImage = function(args) { + if (!this.runtime.ioDevices) { + console.log("No ioDevices available."); + return; + } +this.runtime.ioDevices.video.enableVideo(); + + const canvas = this.runtime.ioDevices.video.getFrame({ + format: Video.FORMAT_CANVAS, + dimensions: [480, 360], // the same as the stage resolution + }); + if (canvas) { + const dataUrl = canvas.toDataURL("image/png"); + const base64Frame = dataUrl.split(",")[1]; + this.unoq.classifyImage(base64Frame); + } else { + console.log("No video frame available for classification."); + } +} + + +module.exports = ArduinoImageClassification; diff --git a/scratch-arduino-extensions/scripts/patch-gui.js b/scratch-arduino-extensions/scripts/patch-gui.js index 5487a7b..5cb793b 100644 --- a/scratch-arduino-extensions/scripts/patch-gui.js +++ b/scratch-arduino-extensions/scripts/patch-gui.js @@ -5,6 +5,7 @@ const extensions = [ { name: "ArduinoBasics", directory: "arduino_basics" }, { name: "ArduinoModulino", directory: "arduino_modulino" }, { name: "ArduinoObjectDetection", directory: "arduino_object_detection" }, + { name: "ArduinoImageClassification", directory: "arduino_image_classification" }, ]; // base dir is the 'scratch-arduino-extensions' folder