First version of at TCP-based OTA updater

This commit is contained in:
Thomas Kolb 2019-11-27 00:22:04 +01:00
parent 24ba2242a4
commit 832b30f485
4 changed files with 283 additions and 1 deletions

18
include/UpdateServer.h Normal file
View file

@ -0,0 +1,18 @@
#pragma once
#include <string>
#include "ChallengeResponse.h"
class UpdateServer
{
public:
UpdateServer(const std::string &pw);
void start(void);
private:
static void updateTask(void *arg);
ChallengeResponse m_cr;
};

87
scripts/ota_update.py Executable file
View file

@ -0,0 +1,87 @@
#!/usr/bin/env python3
import sys
import socket
import struct
import hashlib
import os
def readMessage(sock):
data = sock.recv(4)
length, = struct.unpack(">I", data)
data = sock.recv(1)
success, = struct.unpack("B", data)
data = sock.recv(length-1)
message = data.decode('utf-8')
return success, message
# read the salt from the config file
with open("../data/etc/auth", "r") as authFile:
lineno = 0
for line in authFile:
if lineno == 1:
SALT = line.strip()
break
lineno += 1
_, host, port, filename = sys.argv
# read and store the password from the user
pwd = input("Enter password: ")
s = socket.create_connection( (host, int(port)) )
data = s.recv(4)
length, = struct.unpack(">I", data)
if length != 4:
print(f"Unexpected challenge length: {length}")
exit(1)
data = s.recv(4)
challenge, = struct.unpack(">I", data)
print(f"Challenge: {challenge}")
# build response string
responsestr = pwd + ":" + str(challenge) + ":" + SALT
m = hashlib.sha256()
m.update(responsestr.encode('utf-8'))
response = m.hexdigest()
print(f"Response: {response}")
data = struct.pack(">I64s", len(response), response.encode('ascii') )
s.send(data)
success, message = readMessage(s)
print(f"Server message: {message}")
if not success:
print("Failed.")
exit(1)
with open(filename, "rb") as binfile:
filesize = os.stat(filename).st_size
data = struct.pack(">I", filesize)
s.send(data)
sent_bytes = 0
while True:
data = binfile.read(1024)
if not data:
break
s.send(data)
sent_bytes += len(data)
print(f"Sent {sent_bytes} of {filesize} bytes.")
success, message = readMessage(s)
print(f"Server message: {message}")
if not success:
print("Failed.")
exit(1)

171
src/UpdateServer.cpp Normal file
View file

@ -0,0 +1,171 @@
#include <Arduino.h>
#include <Update.h>
#include <WiFiServer.h>
#include <endian.h>
#include "UpdateServer.h"
UpdateServer::UpdateServer(const std::string &pw)
: m_cr(pw)
{}
void UpdateServer::start(void)
{
xTaskCreate(
updateTask, /* Task function. */
"Update Task", /* name of task. */
10000, /* Stack size of task */
this, /* parameter of the task */
2, /* priority of the task */
NULL); /* Task handle to keep track of created task */
}
static void sendMessage(WiFiClient *client, bool success, const char *message)
{
uint32_t len = htobe32(strlen(message) + 1);
uint8_t status = success ? 1 : 0;
client->write(reinterpret_cast<char*>(&len), sizeof(len));
client->write(status);
client->write(message);
}
static bool read_n(WiFiClient *client, char *buf, size_t n)
{
size_t nread = 0;
while(client->connected() && (nread < n)) {
if(client->available()) {
size_t rcvd = client->readBytes(buf + nread, n - nread);
nread += rcvd;
Serial.print("Received ");
Serial.print(rcvd);
Serial.print(" bytes (");
Serial.print(nread);
Serial.print("/");
Serial.print(n);
Serial.println(" total)");
} else {
Serial.print(".");
delay(1);
}
}
return nread == n;
}
void UpdateServer::updateTask(void *arg)
{
UpdateServer *obj = reinterpret_cast<UpdateServer*>(arg);
WiFiServer server;
server.begin(31337);
while(true) {
WiFiClient client = server.available();
if(client) {
Serial.println("Update client connected.");
// client connected. Send the challenge
uint32_t len = htobe32(4);
client.write(reinterpret_cast<char*>(&len), sizeof(len));
uint32_t nonce = htobe32(obj->m_cr.nonce());
client.write(reinterpret_cast<char*>(&nonce), sizeof(nonce));
// wait for the response
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
Serial.println("Read from update client (response length) failed.");
client.stop();
continue;
}
// check length of the response
len = be32toh(len);
if(len != 64) {
Serial.println("Invalid length of response.");
sendMessage(&client, false, "Invalid response length.");
client.stop();
continue;
}
// read response
char response[65];
if(!read_n(&client, response, 64)) {
Serial.println("Read from update client (response) failed.");
client.stop();
continue;
}
response[64] = '\0';
if(!obj->m_cr.verify(response)) {
Serial.println("Client failed authentication.");
sendMessage(&client, false, "Invalid response.");
client.stop();
continue;
}
Serial.println("Client authenticated.");
// successful authentication
sendMessage(&client, true, "OK");
// read length of the actual update data
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
Serial.println("Read from update client (update data length) failed.");
client.stop();
continue;
}
len = be32toh(len);
Serial.print("Update size: ");
Serial.print(len);
Serial.println(" Byte");
if(!Update.begin(len)) {
Serial.println("Cannot start update.");
sendMessage(&client, false, "Update failed.");
client.stop();
continue;
}
uint32_t bytes_processed = 0;
uint8_t buf[1024];
while(bytes_processed < len) {
if(!client.connected()) {
Serial.println("Connection lost, update aborted.");
Update.abort();
break;
}
if(client.available()) {
size_t bytes_read = client.read(buf, 1024);
Update.write(buf, bytes_read);
bytes_processed += bytes_read;
} else {
delay(1);
}
}
// all data processed?
if(Update.end()) {
// successful update!
Serial.println("Update successful! Will reboot in 3 seconds.");
sendMessage(&client, true, "OK");
delay(3000);
ESP.restart();
} else {
// update failed
Serial.println("Update failed on final check.");
sendMessage(&client, false, "Update failed.");
}
client.stop();
}
delay(100);
}
}

View file

@ -8,6 +8,7 @@
#include "WebServer.h" #include "WebServer.h"
#include "Fader.h" #include "Fader.h"
#include "UDPProto.h" #include "UDPProto.h"
#include "UpdateServer.h"
#include "Config.h" #include "Config.h"
#include <esp32_digital_led_lib.h> #include <esp32_digital_led_lib.h>
@ -31,6 +32,7 @@ WiFiMulti wiFiMulti;
Fader ledFader(NUM_STRIPS, NUM_LEDS, 1, FLIP_STRIPS_MASK); Fader ledFader(NUM_STRIPS, NUM_LEDS, 1, FLIP_STRIPS_MASK);
UDPProto udpProto(&ledFader); UDPProto udpProto(&ledFader);
UpdateServer *updateServer;
bool initLEDs() bool initLEDs()
{ {
@ -209,7 +211,7 @@ void setup()
"LED Task", /* name of task. */ "LED Task", /* name of task. */
10000, /* Stack size of task */ 10000, /* Stack size of task */
NULL, /* parameter of the task */ NULL, /* parameter of the task */
2, /* priority of the task */ 3, /* priority of the task */
NULL); /* Task handle to keep track of created task */ NULL); /* Task handle to keep track of created task */
// Connect the WiFi network (or start an AP if that doesn't work) // Connect the WiFi network (or start an AP if that doesn't work)
@ -228,6 +230,10 @@ void setup()
// start the web server // start the web server
WebServer::instance().setFader(&ledFader); WebServer::instance().setFader(&ledFader);
WebServer::instance().start(); WebServer::instance().start();
// start the update server
updateServer = new UpdateServer(Config::instance().getCRPassword());
updateServer->start();
} }
void loop() { void loop() {