First version of at TCP-based OTA updater
This commit is contained in:
parent
24ba2242a4
commit
832b30f485
18
include/UpdateServer.h
Normal file
18
include/UpdateServer.h
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "ChallengeResponse.h"
|
||||||
|
|
||||||
|
class UpdateServer
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
UpdateServer(const std::string &pw);
|
||||||
|
|
||||||
|
void start(void);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void updateTask(void *arg);
|
||||||
|
|
||||||
|
ChallengeResponse m_cr;
|
||||||
|
};
|
87
scripts/ota_update.py
Executable file
87
scripts/ota_update.py
Executable file
|
@ -0,0 +1,87 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
def readMessage(sock):
|
||||||
|
data = sock.recv(4)
|
||||||
|
length, = struct.unpack(">I", data)
|
||||||
|
data = sock.recv(1)
|
||||||
|
success, = struct.unpack("B", data)
|
||||||
|
data = sock.recv(length-1)
|
||||||
|
message = data.decode('utf-8')
|
||||||
|
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
# read the salt from the config file
|
||||||
|
with open("../data/etc/auth", "r") as authFile:
|
||||||
|
lineno = 0
|
||||||
|
for line in authFile:
|
||||||
|
if lineno == 1:
|
||||||
|
SALT = line.strip()
|
||||||
|
break
|
||||||
|
lineno += 1
|
||||||
|
|
||||||
|
_, host, port, filename = sys.argv
|
||||||
|
|
||||||
|
# read and store the password from the user
|
||||||
|
pwd = input("Enter password: ")
|
||||||
|
|
||||||
|
s = socket.create_connection( (host, int(port)) )
|
||||||
|
|
||||||
|
data = s.recv(4)
|
||||||
|
length, = struct.unpack(">I", data)
|
||||||
|
|
||||||
|
if length != 4:
|
||||||
|
print(f"Unexpected challenge length: {length}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
data = s.recv(4)
|
||||||
|
challenge, = struct.unpack(">I", data)
|
||||||
|
|
||||||
|
print(f"Challenge: {challenge}")
|
||||||
|
|
||||||
|
# build response string
|
||||||
|
responsestr = pwd + ":" + str(challenge) + ":" + SALT
|
||||||
|
|
||||||
|
m = hashlib.sha256()
|
||||||
|
m.update(responsestr.encode('utf-8'))
|
||||||
|
response = m.hexdigest()
|
||||||
|
|
||||||
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
data = struct.pack(">I64s", len(response), response.encode('ascii') )
|
||||||
|
s.send(data)
|
||||||
|
|
||||||
|
success, message = readMessage(s)
|
||||||
|
print(f"Server message: {message}")
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
print("Failed.")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
with open(filename, "rb") as binfile:
|
||||||
|
filesize = os.stat(filename).st_size
|
||||||
|
|
||||||
|
data = struct.pack(">I", filesize)
|
||||||
|
s.send(data)
|
||||||
|
|
||||||
|
sent_bytes = 0
|
||||||
|
while True:
|
||||||
|
data = binfile.read(1024)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
s.send(data)
|
||||||
|
sent_bytes += len(data)
|
||||||
|
|
||||||
|
print(f"Sent {sent_bytes} of {filesize} bytes.")
|
||||||
|
|
||||||
|
success, message = readMessage(s)
|
||||||
|
print(f"Server message: {message}")
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
print("Failed.")
|
||||||
|
exit(1)
|
171
src/UpdateServer.cpp
Normal file
171
src/UpdateServer.cpp
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
#include <Arduino.h>
|
||||||
|
#include <Update.h>
|
||||||
|
#include <WiFiServer.h>
|
||||||
|
|
||||||
|
#include <endian.h>
|
||||||
|
|
||||||
|
#include "UpdateServer.h"
|
||||||
|
|
||||||
|
UpdateServer::UpdateServer(const std::string &pw)
|
||||||
|
: m_cr(pw)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void UpdateServer::start(void)
|
||||||
|
{
|
||||||
|
xTaskCreate(
|
||||||
|
updateTask, /* Task function. */
|
||||||
|
"Update Task", /* name of task. */
|
||||||
|
10000, /* Stack size of task */
|
||||||
|
this, /* parameter of the task */
|
||||||
|
2, /* priority of the task */
|
||||||
|
NULL); /* Task handle to keep track of created task */
|
||||||
|
}
|
||||||
|
|
||||||
|
static void sendMessage(WiFiClient *client, bool success, const char *message)
|
||||||
|
{
|
||||||
|
uint32_t len = htobe32(strlen(message) + 1);
|
||||||
|
uint8_t status = success ? 1 : 0;
|
||||||
|
|
||||||
|
client->write(reinterpret_cast<char*>(&len), sizeof(len));
|
||||||
|
client->write(status);
|
||||||
|
client->write(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool read_n(WiFiClient *client, char *buf, size_t n)
|
||||||
|
{
|
||||||
|
size_t nread = 0;
|
||||||
|
while(client->connected() && (nread < n)) {
|
||||||
|
if(client->available()) {
|
||||||
|
size_t rcvd = client->readBytes(buf + nread, n - nread);
|
||||||
|
nread += rcvd;
|
||||||
|
Serial.print("Received ");
|
||||||
|
Serial.print(rcvd);
|
||||||
|
Serial.print(" bytes (");
|
||||||
|
Serial.print(nread);
|
||||||
|
Serial.print("/");
|
||||||
|
Serial.print(n);
|
||||||
|
Serial.println(" total)");
|
||||||
|
} else {
|
||||||
|
Serial.print(".");
|
||||||
|
delay(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nread == n;
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdateServer::updateTask(void *arg)
|
||||||
|
{
|
||||||
|
UpdateServer *obj = reinterpret_cast<UpdateServer*>(arg);
|
||||||
|
|
||||||
|
WiFiServer server;
|
||||||
|
server.begin(31337);
|
||||||
|
|
||||||
|
while(true) {
|
||||||
|
WiFiClient client = server.available();
|
||||||
|
|
||||||
|
if(client) {
|
||||||
|
Serial.println("Update client connected.");
|
||||||
|
|
||||||
|
// client connected. Send the challenge
|
||||||
|
uint32_t len = htobe32(4);
|
||||||
|
client.write(reinterpret_cast<char*>(&len), sizeof(len));
|
||||||
|
uint32_t nonce = htobe32(obj->m_cr.nonce());
|
||||||
|
client.write(reinterpret_cast<char*>(&nonce), sizeof(nonce));
|
||||||
|
|
||||||
|
// wait for the response
|
||||||
|
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
|
||||||
|
Serial.println("Read from update client (response length) failed.");
|
||||||
|
client.stop();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check length of the response
|
||||||
|
len = be32toh(len);
|
||||||
|
if(len != 64) {
|
||||||
|
Serial.println("Invalid length of response.");
|
||||||
|
sendMessage(&client, false, "Invalid response length.");
|
||||||
|
client.stop();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// read response
|
||||||
|
char response[65];
|
||||||
|
|
||||||
|
if(!read_n(&client, response, 64)) {
|
||||||
|
Serial.println("Read from update client (response) failed.");
|
||||||
|
client.stop();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
response[64] = '\0';
|
||||||
|
if(!obj->m_cr.verify(response)) {
|
||||||
|
Serial.println("Client failed authentication.");
|
||||||
|
sendMessage(&client, false, "Invalid response.");
|
||||||
|
client.stop();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Serial.println("Client authenticated.");
|
||||||
|
|
||||||
|
// successful authentication
|
||||||
|
sendMessage(&client, true, "OK");
|
||||||
|
|
||||||
|
// read length of the actual update data
|
||||||
|
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
|
||||||
|
Serial.println("Read from update client (update data length) failed.");
|
||||||
|
client.stop();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
len = be32toh(len);
|
||||||
|
|
||||||
|
Serial.print("Update size: ");
|
||||||
|
Serial.print(len);
|
||||||
|
Serial.println(" Byte");
|
||||||
|
|
||||||
|
if(!Update.begin(len)) {
|
||||||
|
Serial.println("Cannot start update.");
|
||||||
|
sendMessage(&client, false, "Update failed.");
|
||||||
|
client.stop();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t bytes_processed = 0;
|
||||||
|
uint8_t buf[1024];
|
||||||
|
while(bytes_processed < len) {
|
||||||
|
if(!client.connected()) {
|
||||||
|
Serial.println("Connection lost, update aborted.");
|
||||||
|
Update.abort();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(client.available()) {
|
||||||
|
size_t bytes_read = client.read(buf, 1024);
|
||||||
|
Update.write(buf, bytes_read);
|
||||||
|
|
||||||
|
bytes_processed += bytes_read;
|
||||||
|
} else {
|
||||||
|
delay(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// all data processed?
|
||||||
|
if(Update.end()) {
|
||||||
|
// successful update!
|
||||||
|
Serial.println("Update successful! Will reboot in 3 seconds.");
|
||||||
|
sendMessage(&client, true, "OK");
|
||||||
|
delay(3000);
|
||||||
|
ESP.restart();
|
||||||
|
} else {
|
||||||
|
// update failed
|
||||||
|
Serial.println("Update failed on final check.");
|
||||||
|
sendMessage(&client, false, "Update failed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
client.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
delay(100);
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,7 @@
|
||||||
#include "WebServer.h"
|
#include "WebServer.h"
|
||||||
#include "Fader.h"
|
#include "Fader.h"
|
||||||
#include "UDPProto.h"
|
#include "UDPProto.h"
|
||||||
|
#include "UpdateServer.h"
|
||||||
#include "Config.h"
|
#include "Config.h"
|
||||||
|
|
||||||
#include <esp32_digital_led_lib.h>
|
#include <esp32_digital_led_lib.h>
|
||||||
|
@ -31,6 +32,7 @@ WiFiMulti wiFiMulti;
|
||||||
|
|
||||||
Fader ledFader(NUM_STRIPS, NUM_LEDS, 1, FLIP_STRIPS_MASK);
|
Fader ledFader(NUM_STRIPS, NUM_LEDS, 1, FLIP_STRIPS_MASK);
|
||||||
UDPProto udpProto(&ledFader);
|
UDPProto udpProto(&ledFader);
|
||||||
|
UpdateServer *updateServer;
|
||||||
|
|
||||||
bool initLEDs()
|
bool initLEDs()
|
||||||
{
|
{
|
||||||
|
@ -209,7 +211,7 @@ void setup()
|
||||||
"LED Task", /* name of task. */
|
"LED Task", /* name of task. */
|
||||||
10000, /* Stack size of task */
|
10000, /* Stack size of task */
|
||||||
NULL, /* parameter of the task */
|
NULL, /* parameter of the task */
|
||||||
2, /* priority of the task */
|
3, /* priority of the task */
|
||||||
NULL); /* Task handle to keep track of created task */
|
NULL); /* Task handle to keep track of created task */
|
||||||
|
|
||||||
// Connect the WiFi network (or start an AP if that doesn't work)
|
// Connect the WiFi network (or start an AP if that doesn't work)
|
||||||
|
@ -228,6 +230,10 @@ void setup()
|
||||||
// start the web server
|
// start the web server
|
||||||
WebServer::instance().setFader(&ledFader);
|
WebServer::instance().setFader(&ledFader);
|
||||||
WebServer::instance().start();
|
WebServer::instance().start();
|
||||||
|
|
||||||
|
// start the update server
|
||||||
|
updateServer = new UpdateServer(Config::instance().getCRPassword());
|
||||||
|
updateServer->start();
|
||||||
}
|
}
|
||||||
|
|
||||||
void loop() {
|
void loop() {
|
||||||
|
|
Loading…
Reference in a new issue