esp32-sk6812/src/UpdateServer.cpp

211 lines
6.6 KiB
C++

#include <Arduino.h>
#include <Update.h>
#include <WiFiServer.h>
#include <endian.h>
#include "UpdateServer.h"
#include "coreids.h"
/**
 * Construct the update server.
 *
 * @param pw               Passed to m_cr, the object the update task uses to
 *                         issue a nonce and to verify the client's response.
 * @param ledLockoutMutex  Pointer to a semaphore handle. Only the pointer is
 *                         stored; the update task dereferences it and takes the
 *                         semaphore around every flash write, so the pointee
 *                         must remain valid while the task is running.
 *
 * The constructor does not start anything; call start() to create the task.
 */
UpdateServer::UpdateServer(const std::string &pw, SemaphoreHandle_t *ledLockoutMutex)
: m_cr(pw), m_ledLockoutMutex(ledLockoutMutex)
{}
/**
 * Create the task that runs the update server.
 *
 * The task executes updateTask() with this object as its argument and is
 * pinned to the core selected by CORE_ID_UPDATESERVER. The task handle is
 * not kept and the result of the creation call is not checked.
 */
void UpdateServer::start(void)
{
	const uint32_t taskStackSize = 10000;
	const uint32_t taskPriority  = 2;

	xTaskCreatePinnedToCore(
			updateTask,            // entry point
			"Update Task",         // task name
			taskStackSize,
			this,                  // handed to updateTask() as its argument
			taskPriority,
			nullptr,               // no task handle requested
			CORE_ID_UPDATESERVER); // core the task is pinned to
}
/**
 * Send a status message to the client.
 *
 * Three writes are issued in this order:
 *   1. a 32-bit big-endian length with the value strlen(message) + 1,
 *   2. one status byte: 1 for success, 0 for failure,
 *   3. the message text.
 *
 * @param client   Connection to write to.
 * @param success  Selects the status byte.
 * @param message  NUL-terminated text to send.
 */
static void sendMessage(WiFiClient *client, bool success, const char *message)
{
	uint8_t statusByte;
	if(success) {
		statusByte = 1;
	} else {
		statusByte = 0;
	}

	// the "+ 1" presumably accounts for the status byte that precedes the
	// text -- confirm against the client implementation
	uint32_t lengthBE = htobe32(strlen(message) + 1);

	client->write(reinterpret_cast<char*>(&lengthBE), sizeof(lengthBE));
	client->write(statusByte);
	client->write(message);
}
/**
 * Receive exactly n bytes from the client.
 *
 * Keeps polling until n bytes have been stored in buf or the client is no
 * longer connected. Progress is logged to Serial after every chunk.
 *
 * @param client  Connection to read from.
 * @param buf     Destination buffer; must have room for n bytes.
 * @param n       Number of bytes to receive.
 * @return true if all n bytes were received, false otherwise.
 */
static bool read_n(WiFiClient *client, char *buf, size_t n)
{
	size_t total = 0;

	for(;;) {
		// connection state is checked first, then the byte count
		if(!client->connected()) {
			break;
		}
		if(total >= n) {
			break;
		}

		if(!client->available()) {
			delay(1); // nothing pending; let other tasks run
			continue;
		}

		size_t chunk = client->readBytes(buf + total, n - total);
		total += chunk;

		Serial.print("Received ");
		Serial.print(chunk);
		Serial.print(" bytes (");
		Serial.print(total);
		Serial.print("/");
		Serial.print(n);
		Serial.println(" total)");

		delay(1); // allow other tasks to run during data reception
	}

	return total == n;
}
/**
 * Task body of the update server.
 *
 * Listens on TCP port 31337 and handles one client at a time, forever.
 * Per connection:
 *   1. Send a challenge: 32-bit big-endian length (4) followed by the
 *      32-bit big-endian nonce obtained from m_cr.
 *   2. Receive the response: 32-bit big-endian length, which must be 64,
 *      followed by 64 bytes that are NUL-terminated locally and passed to
 *      m_cr.verify().
 *   3. Receive the update: 32-bit big-endian length (type byte + data),
 *      one type byte (UT_FLASH or UT_SPIFFS), then the data, which is
 *      streamed into Update.
 *   4. On success report "OK", wait three seconds and restart; otherwise
 *      report the failure and close the connection.
 *
 * Every failure after the challenge that still has a usable connection is
 * reported to the client through sendMessage().
 *
 * @param arg  The UpdateServer instance that created the task.
 */
void UpdateServer::updateTask(void *arg)
{
	UpdateServer *obj = reinterpret_cast<UpdateServer*>(arg);

	WiFiServer server;
	server.begin(31337);

	while(true) {
		WiFiClient client = server.available();

		if(client) {
			Serial.println("Update client connected.");

			// client connected. Send the challenge
			uint32_t len = htobe32(4);
			client.write(reinterpret_cast<char*>(&len), sizeof(len));

			uint32_t nonce = htobe32(obj->m_cr.nonce());
			client.write(reinterpret_cast<char*>(&nonce), sizeof(nonce));

			// wait for the response
			if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
				Serial.println("Read from update client (response length) failed.");
				client.stop();
				continue;
			}

			// check length of the response
			len = be32toh(len);
			if(len != 64) {
				Serial.println("Invalid length of response.");
				sendMessage(&client, false, "Invalid response length.");
				client.stop();
				continue;
			}

			// read response; one extra byte for the terminator added below
			char response[65];
			if(!read_n(&client, response, 64)) {
				Serial.println("Read from update client (response) failed.");
				client.stop();
				continue;
			}
			response[64] = '\0';

			if(!obj->m_cr.verify(response)) {
				Serial.println("Client failed authentication.");
				sendMessage(&client, false, "Invalid response.");
				client.stop();
				continue;
			}

			Serial.println("Client authenticated.");

			// successful authentication
			sendMessage(&client, true, "OK");

			// read length of the actual update data
			if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
				Serial.println("Read from update client (update data length) failed.");
				client.stop();
				continue;
			}

			len = be32toh(len);
			if(len <= 1) {
				Serial.println("Update message too small (expected >= 2 byte)");
				sendMessage(&client, false, "Update message too small (expected >= 2 byte).");
				client.stop();
				continue;
			}

			len -= 1; // the announced length includes the type byte read next

			// read update type
			uint8_t update_type;
			if(!read_n(&client, reinterpret_cast<char*>(&update_type), sizeof(update_type))) {
				Serial.println("Read from update client (update type) failed.");
				client.stop();
				continue;
			}

			// translate to internal value
			int update_command;
			switch(update_type) {
				case UT_FLASH:
					update_command = U_FLASH;
					break;

				case UT_SPIFFS:
					update_command = U_SPIFFS;
					break;

				default:
					Serial.println("Invalid update type.");
					sendMessage(&client, false, "Invalid update type");
					client.stop();
					continue;
			}

			Serial.print("Update size: ");
			Serial.print(len);
			Serial.println(" Byte");

			if(!Update.begin(len, update_command)) {
				Serial.println("Cannot start update.");
				sendMessage(&client, false, "Update failed.");
				client.stop();
				continue;
			}

			uint32_t bytes_processed = 0;
			uint8_t buf[1024];

			while(bytes_processed < len) {
				if(!client.connected()) {
					Serial.println("Connection lost, update aborted.");
					Update.abort();
					break;
				}

				if(client.available()) {
					// never request more than what is left of the announced
					// size, so surplus bytes are not fed into the updater
					size_t chunk = len - bytes_processed;
					if(chunk > sizeof(buf)) {
						chunk = sizeof(buf);
					}

					// read() returns int; keep it signed so an error return is
					// not turned into a huge unsigned length
					int bytes_read = client.read(buf, chunk);

					if(bytes_read > 0) {
						// block LED updates while writing flash
						xSemaphoreTake(*obj->m_ledLockoutMutex, portMAX_DELAY);
						size_t bytes_written = Update.write(buf, static_cast<size_t>(bytes_read));
						xSemaphoreGive(*obj->m_ledLockoutMutex);

						if(bytes_written != static_cast<size_t>(bytes_read)) {
							// stop receiving; Update.end() below reports the failure
							Serial.println("Writing update data failed, update aborted.");
							break;
						}

						bytes_processed += bytes_written;
					}
				}

				delay(1);
			}

			// all data processed?
			if(Update.end()) {
				// successful update!
				Serial.println("Update successful! Will reboot in 3 seconds.");
				sendMessage(&client, true, "OK");
				delay(3000);
				ESP.restart();
			} else {
				// update failed
				Serial.println("Update failed on final check.");
				sendMessage(&client, false, "Update failed.");
			}

			client.stop();
		}

		delay(100);
	}
}