First version of at TCP-based OTA updater
This commit is contained in:
parent
24ba2242a4
commit
832b30f485
18
include/UpdateServer.h
Normal file
18
include/UpdateServer.h
Normal file
|
@ -0,0 +1,18 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ChallengeResponse.h"
|
||||
|
||||
class UpdateServer
|
||||
{
|
||||
public:
|
||||
UpdateServer(const std::string &pw);
|
||||
|
||||
void start(void);
|
||||
|
||||
private:
|
||||
static void updateTask(void *arg);
|
||||
|
||||
ChallengeResponse m_cr;
|
||||
};
|
87
scripts/ota_update.py
Executable file
87
scripts/ota_update.py
Executable file
|
@ -0,0 +1,87 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import socket
|
||||
import struct
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
def readMessage(sock):
|
||||
data = sock.recv(4)
|
||||
length, = struct.unpack(">I", data)
|
||||
data = sock.recv(1)
|
||||
success, = struct.unpack("B", data)
|
||||
data = sock.recv(length-1)
|
||||
message = data.decode('utf-8')
|
||||
|
||||
return success, message
|
||||
|
||||
# read the salt from the config file
|
||||
with open("../data/etc/auth", "r") as authFile:
|
||||
lineno = 0
|
||||
for line in authFile:
|
||||
if lineno == 1:
|
||||
SALT = line.strip()
|
||||
break
|
||||
lineno += 1
|
||||
|
||||
_, host, port, filename = sys.argv
|
||||
|
||||
# read and store the password from the user
|
||||
pwd = input("Enter password: ")
|
||||
|
||||
s = socket.create_connection( (host, int(port)) )
|
||||
|
||||
data = s.recv(4)
|
||||
length, = struct.unpack(">I", data)
|
||||
|
||||
if length != 4:
|
||||
print(f"Unexpected challenge length: {length}")
|
||||
exit(1)
|
||||
|
||||
data = s.recv(4)
|
||||
challenge, = struct.unpack(">I", data)
|
||||
|
||||
print(f"Challenge: {challenge}")
|
||||
|
||||
# build response string
|
||||
responsestr = pwd + ":" + str(challenge) + ":" + SALT
|
||||
|
||||
m = hashlib.sha256()
|
||||
m.update(responsestr.encode('utf-8'))
|
||||
response = m.hexdigest()
|
||||
|
||||
print(f"Response: {response}")
|
||||
|
||||
data = struct.pack(">I64s", len(response), response.encode('ascii') )
|
||||
s.send(data)
|
||||
|
||||
success, message = readMessage(s)
|
||||
print(f"Server message: {message}")
|
||||
|
||||
if not success:
|
||||
print("Failed.")
|
||||
exit(1)
|
||||
|
||||
with open(filename, "rb") as binfile:
|
||||
filesize = os.stat(filename).st_size
|
||||
|
||||
data = struct.pack(">I", filesize)
|
||||
s.send(data)
|
||||
|
||||
sent_bytes = 0
|
||||
while True:
|
||||
data = binfile.read(1024)
|
||||
if not data:
|
||||
break
|
||||
s.send(data)
|
||||
sent_bytes += len(data)
|
||||
|
||||
print(f"Sent {sent_bytes} of {filesize} bytes.")
|
||||
|
||||
success, message = readMessage(s)
|
||||
print(f"Server message: {message}")
|
||||
|
||||
if not success:
|
||||
print("Failed.")
|
||||
exit(1)
|
171
src/UpdateServer.cpp
Normal file
171
src/UpdateServer.cpp
Normal file
|
@ -0,0 +1,171 @@
|
|||
#include <Arduino.h>
|
||||
#include <Update.h>
|
||||
#include <WiFiServer.h>
|
||||
|
||||
#include <endian.h>
|
||||
|
||||
#include "UpdateServer.h"
|
||||
|
||||
UpdateServer::UpdateServer(const std::string &pw)
|
||||
: m_cr(pw)
|
||||
{}
|
||||
|
||||
void UpdateServer::start(void)
|
||||
{
|
||||
xTaskCreate(
|
||||
updateTask, /* Task function. */
|
||||
"Update Task", /* name of task. */
|
||||
10000, /* Stack size of task */
|
||||
this, /* parameter of the task */
|
||||
2, /* priority of the task */
|
||||
NULL); /* Task handle to keep track of created task */
|
||||
}
|
||||
|
||||
static void sendMessage(WiFiClient *client, bool success, const char *message)
|
||||
{
|
||||
uint32_t len = htobe32(strlen(message) + 1);
|
||||
uint8_t status = success ? 1 : 0;
|
||||
|
||||
client->write(reinterpret_cast<char*>(&len), sizeof(len));
|
||||
client->write(status);
|
||||
client->write(message);
|
||||
}
|
||||
|
||||
static bool read_n(WiFiClient *client, char *buf, size_t n)
|
||||
{
|
||||
size_t nread = 0;
|
||||
while(client->connected() && (nread < n)) {
|
||||
if(client->available()) {
|
||||
size_t rcvd = client->readBytes(buf + nread, n - nread);
|
||||
nread += rcvd;
|
||||
Serial.print("Received ");
|
||||
Serial.print(rcvd);
|
||||
Serial.print(" bytes (");
|
||||
Serial.print(nread);
|
||||
Serial.print("/");
|
||||
Serial.print(n);
|
||||
Serial.println(" total)");
|
||||
} else {
|
||||
Serial.print(".");
|
||||
delay(1);
|
||||
}
|
||||
}
|
||||
|
||||
return nread == n;
|
||||
}
|
||||
|
||||
void UpdateServer::updateTask(void *arg)
|
||||
{
|
||||
UpdateServer *obj = reinterpret_cast<UpdateServer*>(arg);
|
||||
|
||||
WiFiServer server;
|
||||
server.begin(31337);
|
||||
|
||||
while(true) {
|
||||
WiFiClient client = server.available();
|
||||
|
||||
if(client) {
|
||||
Serial.println("Update client connected.");
|
||||
|
||||
// client connected. Send the challenge
|
||||
uint32_t len = htobe32(4);
|
||||
client.write(reinterpret_cast<char*>(&len), sizeof(len));
|
||||
uint32_t nonce = htobe32(obj->m_cr.nonce());
|
||||
client.write(reinterpret_cast<char*>(&nonce), sizeof(nonce));
|
||||
|
||||
// wait for the response
|
||||
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
|
||||
Serial.println("Read from update client (response length) failed.");
|
||||
client.stop();
|
||||
continue;
|
||||
}
|
||||
|
||||
// check length of the response
|
||||
len = be32toh(len);
|
||||
if(len != 64) {
|
||||
Serial.println("Invalid length of response.");
|
||||
sendMessage(&client, false, "Invalid response length.");
|
||||
client.stop();
|
||||
continue;
|
||||
}
|
||||
|
||||
// read response
|
||||
char response[65];
|
||||
|
||||
if(!read_n(&client, response, 64)) {
|
||||
Serial.println("Read from update client (response) failed.");
|
||||
client.stop();
|
||||
continue;
|
||||
}
|
||||
|
||||
response[64] = '\0';
|
||||
if(!obj->m_cr.verify(response)) {
|
||||
Serial.println("Client failed authentication.");
|
||||
sendMessage(&client, false, "Invalid response.");
|
||||
client.stop();
|
||||
continue;
|
||||
}
|
||||
|
||||
Serial.println("Client authenticated.");
|
||||
|
||||
// successful authentication
|
||||
sendMessage(&client, true, "OK");
|
||||
|
||||
// read length of the actual update data
|
||||
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
|
||||
Serial.println("Read from update client (update data length) failed.");
|
||||
client.stop();
|
||||
continue;
|
||||
}
|
||||
|
||||
len = be32toh(len);
|
||||
|
||||
Serial.print("Update size: ");
|
||||
Serial.print(len);
|
||||
Serial.println(" Byte");
|
||||
|
||||
if(!Update.begin(len)) {
|
||||
Serial.println("Cannot start update.");
|
||||
sendMessage(&client, false, "Update failed.");
|
||||
client.stop();
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t bytes_processed = 0;
|
||||
uint8_t buf[1024];
|
||||
while(bytes_processed < len) {
|
||||
if(!client.connected()) {
|
||||
Serial.println("Connection lost, update aborted.");
|
||||
Update.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
if(client.available()) {
|
||||
size_t bytes_read = client.read(buf, 1024);
|
||||
Update.write(buf, bytes_read);
|
||||
|
||||
bytes_processed += bytes_read;
|
||||
} else {
|
||||
delay(1);
|
||||
}
|
||||
}
|
||||
|
||||
// all data processed?
|
||||
if(Update.end()) {
|
||||
// successful update!
|
||||
Serial.println("Update successful! Will reboot in 3 seconds.");
|
||||
sendMessage(&client, true, "OK");
|
||||
delay(3000);
|
||||
ESP.restart();
|
||||
} else {
|
||||
// update failed
|
||||
Serial.println("Update failed on final check.");
|
||||
sendMessage(&client, false, "Update failed.");
|
||||
}
|
||||
|
||||
client.stop();
|
||||
}
|
||||
|
||||
delay(100);
|
||||
}
|
||||
}
|
|
@ -8,6 +8,7 @@
|
|||
#include "WebServer.h"
|
||||
#include "Fader.h"
|
||||
#include "UDPProto.h"
|
||||
#include "UpdateServer.h"
|
||||
#include "Config.h"
|
||||
|
||||
#include <esp32_digital_led_lib.h>
|
||||
|
@ -31,6 +32,7 @@ WiFiMulti wiFiMulti;
|
|||
|
||||
Fader ledFader(NUM_STRIPS, NUM_LEDS, 1, FLIP_STRIPS_MASK);
|
||||
UDPProto udpProto(&ledFader);
|
||||
UpdateServer *updateServer;
|
||||
|
||||
bool initLEDs()
|
||||
{
|
||||
|
@ -209,7 +211,7 @@ void setup()
|
|||
"LED Task", /* name of task. */
|
||||
10000, /* Stack size of task */
|
||||
NULL, /* parameter of the task */
|
||||
2, /* priority of the task */
|
||||
3, /* priority of the task */
|
||||
NULL); /* Task handle to keep track of created task */
|
||||
|
||||
// Connect the WiFi network (or start an AP if that doesn't work)
|
||||
|
@ -228,6 +230,10 @@ void setup()
|
|||
// start the web server
|
||||
WebServer::instance().setFader(&ledFader);
|
||||
WebServer::instance().start();
|
||||
|
||||
// start the update server
|
||||
updateServer = new UpdateServer(Config::instance().getCRPassword());
|
||||
updateServer->start();
|
||||
}
|
||||
|
||||
void loop() {
|
||||
|
|
Loading…
Reference in a new issue