l2/connection: refactoring

- split handle_packet() in two functions: the first does basic checks and calls
  the second one. This allows to do the basic checks externally (in the
  multi-connection management module) without duplicating them.
- use the new layer2_encode_packet() function.
This commit is contained in:
Thomas Kolb 2024-11-03 14:35:01 +01:00
parent 16f7ae5242
commit 8572018c4e
2 changed files with 74 additions and 40 deletions

View file

@ -80,19 +80,6 @@ result_t connection_handle_packet(connection_ctx_t *ctx, const uint8_t *buf, siz
return OK; return OK;
} }
if(!ham64_is_equal(&header.src_addr, &ctx->peer_addr)) {
char fmt_src_addr[HAM64_FMT_MAX_LEN];
char fmt_peer_addr[HAM64_FMT_MAX_LEN];
ham64_format(&header.src_addr, fmt_src_addr);
ham64_format(&ctx->peer_addr, fmt_peer_addr);
LOG(LVL_ERR, "Packet has the wrong source address: got %s, expected %s",
fmt_src_addr, fmt_peer_addr);
return ERR_INVALID_ADDRESS;
}
if(!ham64_is_equal(&header.dst_addr, &ctx->my_addr)) { if(!ham64_is_equal(&header.dst_addr, &ctx->my_addr)) {
char fmt_dst_addr[HAM64_FMT_MAX_LEN]; char fmt_dst_addr[HAM64_FMT_MAX_LEN];
char fmt_my_addr[HAM64_FMT_MAX_LEN]; char fmt_my_addr[HAM64_FMT_MAX_LEN];
@ -106,44 +93,85 @@ result_t connection_handle_packet(connection_ctx_t *ctx, const uint8_t *buf, siz
return ERR_INVALID_ADDRESS; return ERR_INVALID_ADDRESS;
} }
size_t header_size = layer2_get_encoded_header_size(&header);
const uint8_t *payload = buf + header_size;
size_t payload_len = packet_size - header_size;
return connection_handle_packet_prechecked(ctx, &header, payload, payload_len);
}
result_t connection_handle_packet_prechecked(
connection_ctx_t *ctx,
const layer2_packet_header_t *header,
const uint8_t *payload, size_t payload_len)
{
// check the connection state
switch(ctx->conn_state) {
case CONN_STATE_UNINITIALIZED:
case CONN_STATE_INITIALIZED:
case CONN_STATE_CLOSED:
LOG(LVL_ERR, "Trying to pass packet to connection in state %u", ctx->conn_state);
return ERR_INVALID_STATE;
case CONN_STATE_CONNECTING:
case CONN_STATE_ESTABLISHED:
// in these states, packets can be handled
break;
}
// check if this packet is from our designated peer
if(!ham64_is_equal(&header->src_addr, &ctx->peer_addr)) {
char fmt_src_addr[HAM64_FMT_MAX_LEN];
char fmt_peer_addr[HAM64_FMT_MAX_LEN];
ham64_format(&header->src_addr, fmt_src_addr);
ham64_format(&ctx->peer_addr, fmt_peer_addr);
LOG(LVL_ERR, "Packet has the wrong source address: got %s, expected %s",
fmt_src_addr, fmt_peer_addr);
return ERR_INVALID_ADDRESS;
}
LOG(LVL_DEBUG, "Handling %s packet with rx_seq_nr %u, tx_seq_nr %u.", LOG(LVL_DEBUG, "Handling %s packet with rx_seq_nr %u, tx_seq_nr %u.",
layer2_msg_type_to_string(header.msg_type), header.rx_seq_nr, header.tx_seq_nr); layer2_msg_type_to_string(header->msg_type), header->rx_seq_nr, header->tx_seq_nr);
ctx->last_acked_seq = header.rx_seq_nr; ctx->last_acked_seq = header->rx_seq_nr;
switch(header.msg_type) { switch(header->msg_type) {
case L2_MSG_TYPE_EMPTY: case L2_MSG_TYPE_EMPTY:
LOG(LVL_DEBUG, "Empty packet: accepted ACK for %u.", ctx->last_acked_seq); LOG(LVL_DEBUG, "Empty packet: accepted ACK for %u.", ctx->last_acked_seq);
// handle the acknowledgement internally // handle the acknowledgement internally
connection_handle_ack(ctx, header.rx_seq_nr, false); connection_handle_ack(ctx, header->rx_seq_nr, false);
return OK; // do not ACK and call back return OK; // do not ACK and call back
case L2_MSG_TYPE_CONN_MGMT: case L2_MSG_TYPE_CONN_MGMT:
case L2_MSG_TYPE_CONNECTIONLESS: case L2_MSG_TYPE_CONNECTIONLESS:
LOG(LVL_WARN, "Message type %s is not implemented yet.", layer2_msg_type_to_string(header.msg_type)); LOG(LVL_WARN, "Message type %s is not implemented yet.", layer2_msg_type_to_string(header->msg_type));
return OK; return OK;
case L2_MSG_TYPE_DATA: case L2_MSG_TYPE_DATA:
break; break;
default: default:
LOG(LVL_ERR, "Invalid message type %d.", header.msg_type); LOG(LVL_ERR, "Invalid message type %d.", header->msg_type);
return ERR_INVALID_STATE; return ERR_INVALID_STATE;
} }
if(ctx->next_expected_seq != header.tx_seq_nr) { if(ctx->next_expected_seq != header->tx_seq_nr) {
LOG(LVL_ERR, "Expected sequence number %u, received %u.", ctx->next_expected_seq, header.tx_seq_nr); LOG(LVL_ERR, "Expected sequence number %u, received %u.", ctx->next_expected_seq, header->tx_seq_nr);
return ERR_SEQUENCE; return ERR_SEQUENCE;
} }
ctx->next_expected_seq++; ctx->next_expected_seq++;
ctx->next_expected_seq &= 0xF; ctx->next_expected_seq &= 0xF;
LOG(LVL_INFO, "Received ACK for seq_nr %u in packet seq_nr %u.", header.rx_seq_nr, header.tx_seq_nr); LOG(LVL_INFO, "Received ACK for seq_nr %u in packet seq_nr %u.", header->rx_seq_nr, header->tx_seq_nr);
// handle the acknowledgement internally // handle the acknowledgement internally
connection_handle_ack(ctx, header.rx_seq_nr, true); connection_handle_ack(ctx, header->rx_seq_nr, true);
size_t header_size = layer2_get_encoded_header_size(&header); size_t header_size = layer2_get_encoded_header_size(&header);
@ -267,7 +295,7 @@ size_t connection_encode_next_packet(connection_ctx_t *ctx, uint8_t ack_seq_nr,
case CONN_STATE_INITIALIZED: case CONN_STATE_INITIALIZED:
case CONN_STATE_CLOSED: case CONN_STATE_CLOSED:
LOG(LVL_ERR, "Trying to encode packet in inactive state %u", ctx->conn_state); LOG(LVL_ERR, "Trying to encode packet in inactive state %u", ctx->conn_state);
return ERR_INVALID_STATE; return 0;
case CONN_STATE_CONNECTING: case CONN_STATE_CONNECTING:
case CONN_STATE_ESTABLISHED: case CONN_STATE_ESTABLISHED:
@ -282,30 +310,18 @@ size_t connection_encode_next_packet(connection_ctx_t *ctx, uint8_t ack_seq_nr,
return 0; return 0;
} }
unsigned int crc_size = crc_sizeof_key(PAYLOAD_CRC_SCHEME);
assert(buf_len >= LAYER2_PACKET_HEADER_ENCODED_SIZE_MAX + crc_size + entry->data_len);
layer2_packet_header_t header = entry->header; layer2_packet_header_t header = entry->header;
header.rx_seq_nr = ack_seq_nr; header.rx_seq_nr = ack_seq_nr;
// encode the header // encode the header
LOG(LVL_DEBUG, "Encoding packet with rx_seq_nr %u, tx_seq_nr %u.", header.rx_seq_nr, header.tx_seq_nr); LOG(LVL_DEBUG, "Encoding packet with rx_seq_nr %u, tx_seq_nr %u.", header.rx_seq_nr, header.tx_seq_nr);
size_t packet_size = layer2_encode_packet_header(&header, buf); size_t packet_size = layer2_encode_packet(&header, entry->data, entry->data_len, buf, buf_len);
if(packet_size == 0) {
// add the payload data LOG(LVL_ERR, "Buffer too small for encoded packet!");
if(entry->data) { return 0;
memcpy(buf + packet_size, entry->data, entry->data_len);
} }
packet_size += entry->data_len;
// calculate CRC of everything and append it to the packet
crc_append_key(PAYLOAD_CRC_SCHEME, buf, packet_size);
packet_size += crc_size;
ctx->next_packet_index++; ctx->next_packet_index++;
return packet_size; return packet_size;

View file

@ -71,6 +71,24 @@ void connection_destroy(connection_ctx_t *ctx);
*/ */
result_t connection_handle_packet(connection_ctx_t *ctx, const uint8_t *buf, size_t buf_len); result_t connection_handle_packet(connection_ctx_t *ctx, const uint8_t *buf, size_t buf_len);
/*!\brief Handle a received packet where the header has already been decoded.
*
* This function assumes that the following basic checks were already done:
* - CRC is correct
* - Header can be decoded
* - Destination address is the local address
*
* \param ctx The connection context.
* \param header Pointer to the decoded header structure.
* \param payload Pointer to the payload data.
* \param payload_len Length of the payload data.
* \returns A result code from the packet handling procedure.
*/
result_t connection_handle_packet_prechecked(
connection_ctx_t *ctx,
const layer2_packet_header_t *header,
const uint8_t *payload, size_t payload_len);
/*!\brief Return the sequence number expected next by our side. /*!\brief Return the sequence number expected next by our side.
*/ */
uint8_t connection_get_next_expected_seq(const connection_ctx_t *ctx); uint8_t connection_get_next_expected_seq(const connection_ctx_t *ctx);