171 lines
5.1 KiB
C++
171 lines
5.1 KiB
C++
|
#include <Arduino.h>
|
||
|
#include <Update.h>
|
||
|
#include <WiFiServer.h>
|
||
|
|
||
|
#include <endian.h>
|
||
|
|
||
|
#include "UpdateServer.h"
|
||
|
|
||
|
/**
 * Construct the update server.
 *
 * @param password  value used to initialise the m_cr member, which is later
 *                  asked for nonces (challenges) and used to verify client
 *                  responses in updateTask().
 */
UpdateServer::UpdateServer(const std::string &password)
	: m_cr(password)
{
}
|
||
|
|
||
|
void UpdateServer::start(void)
|
||
|
{
|
||
|
xTaskCreate(
|
||
|
updateTask, /* Task function. */
|
||
|
"Update Task", /* name of task. */
|
||
|
10000, /* Stack size of task */
|
||
|
this, /* parameter of the task */
|
||
|
2, /* priority of the task */
|
||
|
NULL); /* Task handle to keep track of created task */
|
||
|
}
|
||
|
|
||
|
/**
 * Send one framed status message to the update client.
 *
 * Wire format: a 32-bit big-endian length, one status byte (1 = success,
 * 0 = failure), then the characters of @p message. The length field is
 * strlen(message) + 1. That equals the status byte plus the text, because
 * writing a C string through the Print interface sends no terminating NUL.
 *
 * The results of the individual write() calls are not checked.
 */
static void sendMessage(WiFiClient *client, bool success, const char *message)
{
	const size_t text_bytes = strlen(message);
	uint32_t frame_len_be = htobe32(text_bytes + 1);

	uint8_t status_code = 0;
	if(success) {
		status_code = 1;
	}

	client->write(reinterpret_cast<char*>(&frame_len_be), sizeof(frame_len_be));
	client->write(status_code);
	client->write(message);
}
|
||
|
|
||
|
static bool read_n(WiFiClient *client, char *buf, size_t n)
|
||
|
{
|
||
|
size_t nread = 0;
|
||
|
while(client->connected() && (nread < n)) {
|
||
|
if(client->available()) {
|
||
|
size_t rcvd = client->readBytes(buf + nread, n - nread);
|
||
|
nread += rcvd;
|
||
|
Serial.print("Received ");
|
||
|
Serial.print(rcvd);
|
||
|
Serial.print(" bytes (");
|
||
|
Serial.print(nread);
|
||
|
Serial.print("/");
|
||
|
Serial.print(n);
|
||
|
Serial.println(" total)");
|
||
|
} else {
|
||
|
Serial.print(".");
|
||
|
delay(1);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nread == n;
|
||
|
}
|
||
|
|
||
|
void UpdateServer::updateTask(void *arg)
|
||
|
{
|
||
|
UpdateServer *obj = reinterpret_cast<UpdateServer*>(arg);
|
||
|
|
||
|
WiFiServer server;
|
||
|
server.begin(31337);
|
||
|
|
||
|
while(true) {
|
||
|
WiFiClient client = server.available();
|
||
|
|
||
|
if(client) {
|
||
|
Serial.println("Update client connected.");
|
||
|
|
||
|
// client connected. Send the challenge
|
||
|
uint32_t len = htobe32(4);
|
||
|
client.write(reinterpret_cast<char*>(&len), sizeof(len));
|
||
|
uint32_t nonce = htobe32(obj->m_cr.nonce());
|
||
|
client.write(reinterpret_cast<char*>(&nonce), sizeof(nonce));
|
||
|
|
||
|
// wait for the response
|
||
|
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
|
||
|
Serial.println("Read from update client (response length) failed.");
|
||
|
client.stop();
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
// check length of the response
|
||
|
len = be32toh(len);
|
||
|
if(len != 64) {
|
||
|
Serial.println("Invalid length of response.");
|
||
|
sendMessage(&client, false, "Invalid response length.");
|
||
|
client.stop();
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
// read response
|
||
|
char response[65];
|
||
|
|
||
|
if(!read_n(&client, response, 64)) {
|
||
|
Serial.println("Read from update client (response) failed.");
|
||
|
client.stop();
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
response[64] = '\0';
|
||
|
if(!obj->m_cr.verify(response)) {
|
||
|
Serial.println("Client failed authentication.");
|
||
|
sendMessage(&client, false, "Invalid response.");
|
||
|
client.stop();
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
Serial.println("Client authenticated.");
|
||
|
|
||
|
// successful authentication
|
||
|
sendMessage(&client, true, "OK");
|
||
|
|
||
|
// read length of the actual update data
|
||
|
if(!read_n(&client, reinterpret_cast<char*>(&len), sizeof(len))) {
|
||
|
Serial.println("Read from update client (update data length) failed.");
|
||
|
client.stop();
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
len = be32toh(len);
|
||
|
|
||
|
Serial.print("Update size: ");
|
||
|
Serial.print(len);
|
||
|
Serial.println(" Byte");
|
||
|
|
||
|
if(!Update.begin(len)) {
|
||
|
Serial.println("Cannot start update.");
|
||
|
sendMessage(&client, false, "Update failed.");
|
||
|
client.stop();
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
uint32_t bytes_processed = 0;
|
||
|
uint8_t buf[1024];
|
||
|
while(bytes_processed < len) {
|
||
|
if(!client.connected()) {
|
||
|
Serial.println("Connection lost, update aborted.");
|
||
|
Update.abort();
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
if(client.available()) {
|
||
|
size_t bytes_read = client.read(buf, 1024);
|
||
|
Update.write(buf, bytes_read);
|
||
|
|
||
|
bytes_processed += bytes_read;
|
||
|
} else {
|
||
|
delay(1);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// all data processed?
|
||
|
if(Update.end()) {
|
||
|
// successful update!
|
||
|
Serial.println("Update successful! Will reboot in 3 seconds.");
|
||
|
sendMessage(&client, true, "OK");
|
||
|
delay(3000);
|
||
|
ESP.restart();
|
||
|
} else {
|
||
|
// update failed
|
||
|
Serial.println("Update failed on final check.");
|
||
|
sendMessage(&client, false, "Update failed.");
|
||
|
}
|
||
|
|
||
|
client.stop();
|
||
|
}
|
||
|
|
||
|
delay(100);
|
||
|
}
|
||
|
}
|