From 832b30f4854bcd5c1cddd8ebd4dfe2b12b51fc8d Mon Sep 17 00:00:00 2001 From: Thomas Kolb Date: Wed, 27 Nov 2019 00:22:04 +0100 Subject: [PATCH] First version of at TCP-based OTA updater --- include/UpdateServer.h | 18 +++++ scripts/ota_update.py | 87 +++++++++++++++++++++ src/UpdateServer.cpp | 171 +++++++++++++++++++++++++++++++++++++++++ src/main.cpp | 8 +- 4 files changed, 283 insertions(+), 1 deletion(-) create mode 100644 include/UpdateServer.h create mode 100755 scripts/ota_update.py create mode 100644 src/UpdateServer.cpp diff --git a/include/UpdateServer.h b/include/UpdateServer.h new file mode 100644 index 0000000..3b992f6 --- /dev/null +++ b/include/UpdateServer.h @@ -0,0 +1,18 @@ +#pragma once + +#include + +#include "ChallengeResponse.h" + +class UpdateServer +{ + public: + UpdateServer(const std::string &pw); + + void start(void); + + private: + static void updateTask(void *arg); + + ChallengeResponse m_cr; +}; \ No newline at end of file diff --git a/scripts/ota_update.py b/scripts/ota_update.py new file mode 100755 index 0000000..4a255ac --- /dev/null +++ b/scripts/ota_update.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +import sys +import socket +import struct +import hashlib +import os + +def readMessage(sock): + data = sock.recv(4) + length, = struct.unpack(">I", data) + data = sock.recv(1) + success, = struct.unpack("B", data) + data = sock.recv(length-1) + message = data.decode('utf-8') + + return success, message + +# read the salt from the config file +with open("../data/etc/auth", "r") as authFile: + lineno = 0 + for line in authFile: + if lineno == 1: + SALT = line.strip() + break + lineno += 1 + +_, host, port, filename = sys.argv + +# read and store the password from the user +pwd = input("Enter password: ") + +s = socket.create_connection( (host, int(port)) ) + +data = s.recv(4) +length, = struct.unpack(">I", data) + +if length != 4: + print(f"Unexpected challenge length: {length}") + exit(1) + +data = s.recv(4) +challenge, = struct.unpack(">I", data) + +print(f"Challenge: {challenge}") + +# build response string +responsestr = pwd + ":" + str(challenge) + ":" + SALT + +m = hashlib.sha256() +m.update(responsestr.encode('utf-8')) +response = m.hexdigest() + +print(f"Response: {response}") + +data = struct.pack(">I64s", len(response), response.encode('ascii') ) +s.send(data) + +success, message = readMessage(s) +print(f"Server message: {message}") + +if not success: + print("Failed.") + exit(1) + +with open(filename, "rb") as binfile: + filesize = os.stat(filename).st_size + + data = struct.pack(">I", filesize) + s.send(data) + + sent_bytes = 0 + while True: + data = binfile.read(1024) + if not data: + break + s.send(data) + sent_bytes += len(data) + + print(f"Sent {sent_bytes} of {filesize} bytes.") + +success, message = readMessage(s) +print(f"Server message: {message}") + +if not success: + print("Failed.") + exit(1) \ No newline at end of file diff --git a/src/UpdateServer.cpp b/src/UpdateServer.cpp new file mode 100644 index 0000000..a2ec925 --- /dev/null +++ b/src/UpdateServer.cpp @@ -0,0 +1,171 @@ +#include +#include +#include + +#include + +#include "UpdateServer.h" + +UpdateServer::UpdateServer(const std::string &pw) + : m_cr(pw) +{} + +void UpdateServer::start(void) +{ + xTaskCreate( + updateTask, /* Task function. */ + "Update Task", /* name of task. */ + 10000, /* Stack size of task */ + this, /* parameter of the task */ + 2, /* priority of the task */ + NULL); /* Task handle to keep track of created task */ +} + +static void sendMessage(WiFiClient *client, bool success, const char *message) +{ + uint32_t len = htobe32(strlen(message) + 1); + uint8_t status = success ? 1 : 0; + + client->write(reinterpret_cast(&len), sizeof(len)); + client->write(status); + client->write(message); +} + +static bool read_n(WiFiClient *client, char *buf, size_t n) +{ + size_t nread = 0; + while(client->connected() && (nread < n)) { + if(client->available()) { + size_t rcvd = client->readBytes(buf + nread, n - nread); + nread += rcvd; + Serial.print("Received "); + Serial.print(rcvd); + Serial.print(" bytes ("); + Serial.print(nread); + Serial.print("/"); + Serial.print(n); + Serial.println(" total)"); + } else { + Serial.print("."); + delay(1); + } + } + + return nread == n; +} + +void UpdateServer::updateTask(void *arg) +{ + UpdateServer *obj = reinterpret_cast(arg); + + WiFiServer server; + server.begin(31337); + + while(true) { + WiFiClient client = server.available(); + + if(client) { + Serial.println("Update client connected."); + + // client connected. Send the challenge + uint32_t len = htobe32(4); + client.write(reinterpret_cast(&len), sizeof(len)); + uint32_t nonce = htobe32(obj->m_cr.nonce()); + client.write(reinterpret_cast(&nonce), sizeof(nonce)); + + // wait for the response + if(!read_n(&client, reinterpret_cast(&len), sizeof(len))) { + Serial.println("Read from update client (response length) failed."); + client.stop(); + continue; + } + + // check length of the response + len = be32toh(len); + if(len != 64) { + Serial.println("Invalid length of response."); + sendMessage(&client, false, "Invalid response length."); + client.stop(); + continue; + } + + // read response + char response[65]; + + if(!read_n(&client, response, 64)) { + Serial.println("Read from update client (response) failed."); + client.stop(); + continue; + } + + response[64] = '\0'; + if(!obj->m_cr.verify(response)) { + Serial.println("Client failed authentication."); + sendMessage(&client, false, "Invalid response."); + client.stop(); + continue; + } + + Serial.println("Client authenticated."); + + // successful authentication + sendMessage(&client, true, "OK"); + + // read length of the actual update data + if(!read_n(&client, reinterpret_cast(&len), sizeof(len))) { + Serial.println("Read from update client (update data length) failed."); + client.stop(); + continue; + } + + len = be32toh(len); + + Serial.print("Update size: "); + Serial.print(len); + Serial.println(" Byte"); + + if(!Update.begin(len)) { + Serial.println("Cannot start update."); + sendMessage(&client, false, "Update failed."); + client.stop(); + continue; + } + + uint32_t bytes_processed = 0; + uint8_t buf[1024]; + while(bytes_processed < len) { + if(!client.connected()) { + Serial.println("Connection lost, update aborted."); + Update.abort(); + break; + } + + if(client.available()) { + size_t bytes_read = client.read(buf, 1024); + Update.write(buf, bytes_read); + + bytes_processed += bytes_read; + } else { + delay(1); + } + } + + // all data processed? + if(Update.end()) { + // successful update! + Serial.println("Update successful! Will reboot in 3 seconds."); + sendMessage(&client, true, "OK"); + delay(3000); + ESP.restart(); + } else { + // update failed + Serial.println("Update failed on final check."); + sendMessage(&client, false, "Update failed."); + } + + client.stop(); + } + + delay(100); + } +} \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 1bad3c0..8194f8a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -8,6 +8,7 @@ #include "WebServer.h" #include "Fader.h" #include "UDPProto.h" +#include "UpdateServer.h" #include "Config.h" #include @@ -31,6 +32,7 @@ WiFiMulti wiFiMulti; Fader ledFader(NUM_STRIPS, NUM_LEDS, 1, FLIP_STRIPS_MASK); UDPProto udpProto(&ledFader); +UpdateServer *updateServer; bool initLEDs() { @@ -209,7 +211,7 @@ void setup() "LED Task", /* name of task. */ 10000, /* Stack size of task */ NULL, /* parameter of the task */ - 2, /* priority of the task */ + 3, /* priority of the task */ NULL); /* Task handle to keep track of created task */ // Connect the WiFi network (or start an AP if that doesn't work) @@ -228,6 +230,10 @@ void setup() // start the web server WebServer::instance().setFader(&ledFader); WebServer::instance().start(); + + // start the update server + updateServer = new UpdateServer(Config::instance().getCRPassword()); + updateServer->start(); } void loop() {