204 lines
6.2 KiB
C++
204 lines
6.2 KiB
C++
#include <Arduino.h>
|
|
#include <Update.h>
|
|
#include <WiFiServer.h>
|
|
|
|
#include <endian.h>
|
|
|
|
#include "UpdateServer.h"
|
|
|
|
// Construct the update server. The password is handed straight to the
// challenge/response helper m_cr, which updateTask() later uses to issue the
// nonce and to verify the client's response.
UpdateServer::UpdateServer(const std::string &pw)
    : m_cr(pw)
{
    // nothing else to initialise
}
|
|
|
|
void UpdateServer::start(void)
|
|
{
|
|
xTaskCreate(
|
|
updateTask, /* Task function. */
|
|
"Update Task", /* name of task. */
|
|
10000, /* Stack size of task */
|
|
this, /* parameter of the task */
|
|
2, /* priority of the task */
|
|
NULL); /* Task handle to keep track of created task */
|
|
}
|
|
|
|
// Send a status reply to the update client. Wire format:
//   4-byte big-endian length = strlen(message) + 1,
//   1 status byte (1 = success, 0 = failure),
//   the characters of `message`.
// NOTE(review): the +1 presumably accounts for the status byte, matching the
// incoming update frame where the type byte is included in the length —
// confirm against the client implementation.
static void sendMessage(WiFiClient *client, bool success, const char *message)
{
    const size_t text_len = strlen(message);
    const uint32_t frame_len_be = htobe32(text_len + 1);

    uint8_t status_byte = 0;
    if(success) {
        status_byte = 1;
    }

    client->write(reinterpret_cast<const char*>(&frame_len_be), sizeof(frame_len_be));
    client->write(status_byte);
    client->write(message);
}
|
|
|
|
static bool read_n(WiFiClient *client, char *buf, size_t n)
|
|
{
|
|
size_t nread = 0;
|
|
while(client->connected() && (nread < n)) {
|
|
if(client->available()) {
|
|
size_t rcvd = client->readBytes(buf + nread, n - nread);
|
|
nread += rcvd;
|
|
Serial.print("Received ");
|
|
Serial.print(rcvd);
|
|
Serial.print(" bytes (");
|
|
Serial.print(nread);
|
|
Serial.print("/");
|
|
Serial.print(n);
|
|
Serial.println(" total)");
|
|
}
|
|
|
|
delay(1); // allow other tasks to run during data reception
|
|
}
|
|
|
|
return nread == n;
|
|
}
|
|
|
|
void UpdateServer::updateTask(void *arg)
|
|
{
|
|
UpdateServer *obj = reinterpret_cast<UpdateServer*>(arg);
|
|
|
|
WiFiServer server;
|
|
server.begin(31337);
|
|
|
|
while(true) {
|
|
WiFiClient client = server.available();
|
|
|
|
if(client) {
|
|
Serial.println("Update client connected.");
|
|
|
|
// client connected. Send the challenge
|
|
uint32_t len = htobe32(4);
|
|
client.write(reinterpret_cast<char*>(&len), sizeof(len));
|
|
uint32_t nonce = htobe32(obj->m_cr.nonce());
|
|
client.write(reinterpret_cast<char*>(&nonce), sizeof(nonce));
|
|
|
|
// wait for the response
|
|
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
|
|
Serial.println("Read from update client (response length) failed.");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
// check length of the response
|
|
len = be32toh(len);
|
|
if(len != 64) {
|
|
Serial.println("Invalid length of response.");
|
|
sendMessage(&client, false, "Invalid response length.");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
// read response
|
|
char response[65];
|
|
|
|
if(!read_n(&client, response, 64)) {
|
|
Serial.println("Read from update client (response) failed.");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
response[64] = '\0';
|
|
if(!obj->m_cr.verify(response)) {
|
|
Serial.println("Client failed authentication.");
|
|
sendMessage(&client, false, "Invalid response.");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
Serial.println("Client authenticated.");
|
|
|
|
// successful authentication
|
|
sendMessage(&client, true, "OK");
|
|
|
|
// read length of the actual update data
|
|
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
|
|
Serial.println("Read from update client (update data length) failed.");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
len = be32toh(len);
|
|
|
|
if(len <= 1) {
|
|
Serial.println("Update message too small (expected >= 2 byte)");
|
|
sendMessage(&client, false, "Update message too small (expected >= 2 byte).");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
len -= 1; // remove the previously read byte for type
|
|
|
|
// read update type
|
|
uint8_t update_type;
|
|
if(!read_n(&client, reinterpret_cast<char*>(&update_type), sizeof(update_type))) {
|
|
Serial.println("Read from update client (update type) failed.");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
// translate to internal value
|
|
int update_command;
|
|
switch(update_type) {
|
|
case UT_FLASH:
|
|
update_command = U_FLASH;
|
|
break;
|
|
case UT_SPIFFS:
|
|
update_command = U_SPIFFS;
|
|
break;
|
|
default:
|
|
Serial.println("Invalid update type.");
|
|
sendMessage(&client, false, "Invalid update type");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
Serial.print("Update size: ");
|
|
Serial.print(len);
|
|
Serial.println(" Byte");
|
|
|
|
if(!Update.begin(len, update_command)) {
|
|
Serial.println("Cannot start update.");
|
|
sendMessage(&client, false, "Update failed.");
|
|
client.stop();
|
|
continue;
|
|
}
|
|
|
|
uint32_t bytes_processed = 0;
|
|
uint8_t buf[1024];
|
|
while(bytes_processed < len) {
|
|
if(!client.connected()) {
|
|
Serial.println("Connection lost, update aborted.");
|
|
Update.abort();
|
|
break;
|
|
}
|
|
|
|
if(client.available()) {
|
|
size_t bytes_read = client.read(buf, 1024);
|
|
Update.write(buf, bytes_read);
|
|
|
|
bytes_processed += bytes_read;
|
|
}
|
|
|
|
delay(1);
|
|
}
|
|
|
|
// all data processed?
|
|
if(Update.end()) {
|
|
// successful update!
|
|
Serial.println("Update successful! Will reboot in 3 seconds.");
|
|
sendMessage(&client, true, "OK");
|
|
delay(3000);
|
|
ESP.restart();
|
|
} else {
|
|
// update failed
|
|
Serial.println("Update failed on final check.");
|
|
sendMessage(&client, false, "Update failed.");
|
|
}
|
|
|
|
client.stop();
|
|
}
|
|
|
|
delay(100);
|
|
}
|
|
}
|