diff --git a/impl/src/layer2/connection.c b/impl/src/layer2/connection.c index 8b623ea..d6f6f5e 100644 --- a/impl/src/layer2/connection.c +++ b/impl/src/layer2/connection.c @@ -80,19 +80,6 @@ result_t connection_handle_packet(connection_ctx_t *ctx, const uint8_t *buf, siz return OK; } - if(!ham64_is_equal(&header.src_addr, &ctx->peer_addr)) { - char fmt_src_addr[HAM64_FMT_MAX_LEN]; - char fmt_peer_addr[HAM64_FMT_MAX_LEN]; - - ham64_format(&header.src_addr, fmt_src_addr); - ham64_format(&ctx->peer_addr, fmt_peer_addr); - - LOG(LVL_ERR, "Packet has the wrong source address: got %s, expected %s", - fmt_src_addr, fmt_peer_addr); - - return ERR_INVALID_ADDRESS; - } - if(!ham64_is_equal(&header.dst_addr, &ctx->my_addr)) { char fmt_dst_addr[HAM64_FMT_MAX_LEN]; char fmt_my_addr[HAM64_FMT_MAX_LEN]; @@ -106,44 +93,85 @@ result_t connection_handle_packet(connection_ctx_t *ctx, const uint8_t *buf, siz return ERR_INVALID_ADDRESS; } + size_t header_size = layer2_get_encoded_header_size(&header); + + const uint8_t *payload = buf + header_size; + size_t payload_len = packet_size - header_size; + + return connection_handle_packet_prechecked(ctx, &header, payload, payload_len); +} + +result_t connection_handle_packet_prechecked( + connection_ctx_t *ctx, + const layer2_packet_header_t *header, + const uint8_t *payload, size_t payload_len) +{ + // check the connection state + switch(ctx->conn_state) { + case CONN_STATE_UNINITIALIZED: + case CONN_STATE_INITIALIZED: + case CONN_STATE_CLOSED: + LOG(LVL_ERR, "Trying to pass packet to connection in state %u", ctx->conn_state); + return ERR_INVALID_STATE; + + case CONN_STATE_CONNECTING: + case CONN_STATE_ESTABLISHED: + // in these states, packets can be handled + break; + } + + // check if this packet is from our designated peer + if(!ham64_is_equal(&header->src_addr, &ctx->peer_addr)) { + char fmt_src_addr[HAM64_FMT_MAX_LEN]; + char fmt_peer_addr[HAM64_FMT_MAX_LEN]; + + ham64_format(&header->src_addr, fmt_src_addr); + ham64_format(&ctx->peer_addr, fmt_peer_addr); + + LOG(LVL_ERR, "Packet has the wrong source address: got %s, expected %s", + fmt_src_addr, fmt_peer_addr); + + return ERR_INVALID_ADDRESS; + } + LOG(LVL_DEBUG, "Handling %s packet with rx_seq_nr %u, tx_seq_nr %u.", - layer2_msg_type_to_string(header.msg_type), header.rx_seq_nr, header.tx_seq_nr); + layer2_msg_type_to_string(header->msg_type), header->rx_seq_nr, header->tx_seq_nr); - ctx->last_acked_seq = header.rx_seq_nr; + ctx->last_acked_seq = header->rx_seq_nr; - switch(header.msg_type) { + switch(header->msg_type) { case L2_MSG_TYPE_EMPTY: LOG(LVL_DEBUG, "Empty packet: accepted ACK for %u.", ctx->last_acked_seq); // handle the acknowledgement internally - connection_handle_ack(ctx, header.rx_seq_nr, false); + connection_handle_ack(ctx, header->rx_seq_nr, false); return OK; // do not ACK and call back case L2_MSG_TYPE_CONN_MGMT: case L2_MSG_TYPE_CONNECTIONLESS: - LOG(LVL_WARN, "Message type %s is not implemented yet.", layer2_msg_type_to_string(header.msg_type)); + LOG(LVL_WARN, "Message type %s is not implemented yet.", layer2_msg_type_to_string(header->msg_type)); return OK; case L2_MSG_TYPE_DATA: break; default: - LOG(LVL_ERR, "Invalid message type %d.", header.msg_type); + LOG(LVL_ERR, "Invalid message type %d.", header->msg_type); return ERR_INVALID_STATE; } - if(ctx->next_expected_seq != header.tx_seq_nr) { - LOG(LVL_ERR, "Expected sequence number %u, received %u.", ctx->next_expected_seq, header.tx_seq_nr); + if(ctx->next_expected_seq != header->tx_seq_nr) { + LOG(LVL_ERR, "Expected sequence number %u, received %u.", ctx->next_expected_seq, header->tx_seq_nr); return ERR_SEQUENCE; } ctx->next_expected_seq++; ctx->next_expected_seq &= 0xF; - LOG(LVL_INFO, "Received ACK for seq_nr %u in packet seq_nr %u.", header.rx_seq_nr, header.tx_seq_nr); + LOG(LVL_INFO, "Received ACK for seq_nr %u in packet seq_nr %u.", header->rx_seq_nr, header->tx_seq_nr); // handle the acknowledgement internally - connection_handle_ack(ctx, header.rx_seq_nr, true); + connection_handle_ack(ctx, header->rx_seq_nr, true); size_t header_size = layer2_get_encoded_header_size(&header); @@ -267,7 +295,7 @@ size_t connection_encode_next_packet(connection_ctx_t *ctx, uint8_t ack_seq_nr, case CONN_STATE_INITIALIZED: case CONN_STATE_CLOSED: LOG(LVL_ERR, "Trying to encode packet in inactive state %u", ctx->conn_state); - return ERR_INVALID_STATE; + return 0; case CONN_STATE_CONNECTING: case CONN_STATE_ESTABLISHED: @@ -282,30 +310,18 @@ size_t connection_encode_next_packet(connection_ctx_t *ctx, uint8_t ack_seq_nr, return 0; } - unsigned int crc_size = crc_sizeof_key(PAYLOAD_CRC_SCHEME); - - assert(buf_len >= LAYER2_PACKET_HEADER_ENCODED_SIZE_MAX + crc_size + entry->data_len); - layer2_packet_header_t header = entry->header; header.rx_seq_nr = ack_seq_nr; // encode the header LOG(LVL_DEBUG, "Encoding packet with rx_seq_nr %u, tx_seq_nr %u.", header.rx_seq_nr, header.tx_seq_nr); - size_t packet_size = layer2_encode_packet_header(&header, buf); - - // add the payload data - if(entry->data) { - memcpy(buf + packet_size, entry->data, entry->data_len); + size_t packet_size = layer2_encode_packet(&header, entry->data, entry->data_len, buf, buf_len); + if(packet_size == 0) { + LOG(LVL_ERR, "Buffer too small for encoded packet!"); + return 0; } - packet_size += entry->data_len; - - // calculate CRC of everything and append it to the packet - crc_append_key(PAYLOAD_CRC_SCHEME, buf, packet_size); - - packet_size += crc_size; - ctx->next_packet_index++; return packet_size; diff --git a/impl/src/layer2/connection.h b/impl/src/layer2/connection.h index f1a2f7f..8ea5791 100644 --- a/impl/src/layer2/connection.h +++ b/impl/src/layer2/connection.h @@ -71,6 +71,24 @@ void connection_destroy(connection_ctx_t *ctx); */ result_t connection_handle_packet(connection_ctx_t *ctx, const uint8_t *buf, size_t buf_len); +/*!\brief Handle a received packet where the header has already been decoded. + * + * This function assumes that the following basic checks were already done: + * - CRC is correct + * - Header can be decoded + * - Destination address is the local address + * + * \param ctx The connection context. + * \param header Pointer to the decoded header structure. + * \param payload Pointer to the payload data. + * \param payload_len Length of the payload data. + * \returns A result code from the packet handling procedure. + */ +result_t connection_handle_packet_prechecked( + connection_ctx_t *ctx, + const layer2_packet_header_t *header, + const uint8_t *payload, size_t payload_len); + /*!\brief Return the sequence number expected next by our side. */ uint8_t connection_get_next_expected_seq(const connection_ctx_t *ctx);