diff --git a/include/UpdateServer.h b/include/UpdateServer.h index 3b992f6..c94c448 100644 --- a/include/UpdateServer.h +++ b/include/UpdateServer.h @@ -7,6 +7,11 @@ class UpdateServer { public: + enum UpdateTypes { + UT_FLASH = 0, + UT_SPIFFS = 1 + }; + UpdateServer(const std::string &pw); void start(void); @@ -15,4 +20,4 @@ class UpdateServer static void updateTask(void *arg); ChallengeResponse m_cr; -}; \ No newline at end of file +}; diff --git a/scripts/ota_update.py b/scripts/ota_update.py index 200731f..d0cfc88 100755 --- a/scripts/ota_update.py +++ b/scripts/ota_update.py @@ -9,6 +9,10 @@ import time from getpass import getpass +# update types +FLASH = 0 +SPIFFS = 1 + def readMessage(sock): data = sock.recv(4) length, = struct.unpack(">I", data) @@ -28,7 +32,15 @@ with open("../data/etc/auth", "r") as authFile: break lineno += 1 -_, host, port, filename = sys.argv +_, host, port, updatetypestr, filename = sys.argv + +if updatetypestr == "flash": + updatetype = FLASH +elif updatetypestr == "spiffs": + updatetype = SPIFFS +else: + print(f"Invalid update type {updatetypestr}. Valid types are: flash, spiffs") + exit(1) # read and store the password from the user pwd = getpass() @@ -70,8 +82,12 @@ print() # for proper progress display with open(filename, "rb") as binfile: filesize = os.stat(filename).st_size + messagesize = filesize + 1 # for type byte - data = struct.pack(">I", filesize) + data = struct.pack(">I", messagesize) + s.send(data) + + data = struct.pack("B", updatetype) s.send(data) sent_bytes = 0 diff --git a/src/UpdateServer.cpp b/src/UpdateServer.cpp index a2121da..dd6c2a1 100644 --- a/src/UpdateServer.cpp +++ b/src/UpdateServer.cpp @@ -119,11 +119,44 @@ void UpdateServer::updateTask(void *arg) len = be32toh(len); + if(len <= 1) { + Serial.println("Update message too small (expected >= 2 byte)"); + sendMessage(&client, false, "Update message too small (expected >= 2 byte)."); + client.stop(); + continue; + } + + len -= 1; // remove the previously read byte for type + + // read update type + uint8_t update_type; + if(!read_n(&client, reinterpret_cast(&update_type), sizeof(update_type))) { + Serial.println("Read from update client (update type) failed."); + client.stop(); + continue; + } + + // translate to internal value + int update_command; + switch(update_type) { + case UT_FLASH: + update_command = U_FLASH; + break; + case UT_SPIFFS: + update_command = U_SPIFFS; + break; + default: + Serial.println("Invalid update type."); + sendMessage(&client, false, "Invalid update type"); + client.stop(); + continue; + } + Serial.print("Update size: "); Serial.print(len); Serial.println(" Byte"); - if(!Update.begin(len)) { + if(!Update.begin(len, update_command)) { Serial.println("Cannot start update."); sendMessage(&client, false, "Update failed."); client.stop(); @@ -167,4 +200,4 @@ void UpdateServer::updateTask(void *arg) delay(100); } -} \ No newline at end of file +}