diff --git a/include/templates.h b/include/templates.h index 0742c2c..aa80e49 100644 --- a/include/templates.h +++ b/include/templates.h @@ -89,4 +89,13 @@ " " \ FOOTER +#define ERROR_401 \ + HEADER1 \ + " Error 401 - Unauthorized" \ + HEADER2 \ + "

Error 401 - Unauthorized

" \ + "

You are not authorized to see the requested content.

" \ + "

Enter the correct username and password!

" \ + FOOTER + #endif // TEMPLATES_H diff --git a/src/main.c b/src/main.c index b9e2e25..d582d95 100644 --- a/src/main.c +++ b/src/main.c @@ -45,6 +45,8 @@ #define DEFAULT_PORT 8888 char *shareRoot; +char *password; +struct MHD_Response *error401Response; struct MHD_Response *error403Response; struct MHD_Response *error404Response; struct MHD_Response *error500Response; @@ -419,6 +421,36 @@ int key_value_iterator(void *cls, } } +// returns 1 on successful authentication and 0 when auth. failed +static int check_authorization(struct MHD_Connection *connection) { + char *auth_pass; + char *auth_user; + int fail; + + if(password == NULL) { + return 1; + } + + auth_pass = NULL; + auth_user = MHD_basic_auth_get_username_password(connection, &auth_pass); + + fail = (auth_user == NULL) || (0 != strcmp(auth_user, "fileshare")) || (0 != strcmp(auth_pass, password)); + + if (auth_user != NULL) { + free(auth_user); + } + + if (auth_pass != NULL) { + free(auth_pass); + } + + if(fail) { + return 0; + } else { + return 1; + } +} + static int connection_handler(void * cls, struct MHD_Connection *connection, @@ -493,6 +525,16 @@ static int connection_handler(void * cls, return MHD_queue_response(connection, MHD_HTTP_FORBIDDEN, error403Response); } + // check the authentication, if enabled + if(!check_authorization(connection)) { + LOG(LVL_WARN, + "Unauthorized request from %s for %s .", + connstate->clientIP, connstate->localFileName); + + return MHD_queue_basic_auth_fail_response( + connection, "Fileshare " VERSION, error401Response); + } + // check properties of the target file/dir if(stat(connstate->localFileName, &(connstate->targetStat)) == -1) { LOG(LVL_ERR, "Cannot stat %s: %s", @@ -609,13 +651,14 @@ void print_urls(int port) { freeifaddrs(ifaddr); } -int parse_cmdline(int argc, char **argv, int *port, int *enableUpload, char **shareRoot) { +int parse_cmdline(int argc, char **argv, int *port, int *enableUpload, char **shareRoot, char **password) { int c; *enableUpload = 0; *port = DEFAULT_PORT; + *password = NULL; - while ((c = getopt (argc, argv, "up:")) != -1) { + while ((c = getopt (argc, argv, "up:P:")) != -1) { switch (c) { case 'u': *enableUpload = 1; @@ -632,6 +675,10 @@ int parse_cmdline(int argc, char **argv, int *port, int *enableUpload, char **sh return 1; } break; + case 'P': + // a password was given on the command line. + *password = optarg; + break; case '?': if (optopt == 'p') { LOG(LVL_ERR, "Option -%c requires an argument.\n", optopt); @@ -674,14 +721,15 @@ int main(int argc, char ** argv) { #endif // parse command line arguments - if(parse_cmdline(argc, argv, &port, &uploadEnabled, &shareRoot)) { + if(parse_cmdline(argc, argv, &port, &uploadEnabled, &shareRoot, &password)) { LOG(LVL_ERR, "Failed to parse command line!"); LOG(LVL_INFO, "Usage: %s [arguments]

", argv[0]); LOG(LVL_INFO, ""); LOG(LVL_INFO, "Arguments:"); LOG(LVL_INFO, ""); - LOG(LVL_INFO, "\t-u Enable Uploads"); - LOG(LVL_INFO, "\t-p port Change the listening port."); + LOG(LVL_INFO, "\t-u Enable Uploads"); + LOG(LVL_INFO, "\t-p port Change the listening port."); + LOG(LVL_INFO, "\t-P password Optional password for HTTP authentication."); LOG(LVL_INFO, ""); return 1; } @@ -690,6 +738,12 @@ int main(int argc, char ** argv) { LOG(LVL_WARN, "Uploads are enabled. Users can create new files anywhere in %s !", shareRoot); } + if(password != NULL) { + LOG(LVL_INFO, "HTTP Authentication enabled. Clients must log in as user 'fileshare', password '%s'.", password); + } else { + LOG(LVL_WARN, "HTTP Authentication disabled, everybody can download the shared files!"); + } + // check if shareRoot is an existing directory if(stat(shareRoot, &sBuf) == -1) { LOG(LVL_FATAL, "Cannot stat %s: %s", shareRoot, strerror(errno)); @@ -728,6 +782,12 @@ int main(int argc, char ** argv) { init_signal_handlers(); // create the static response for error pages + error401Response = MHD_create_response_from_data( + strlen(ERROR_401), + (void*) ERROR_401, + MHD_NO, + MHD_NO); + error403Response = MHD_create_response_from_data( strlen(ERROR_403), (void*) ERROR_403, @@ -746,6 +806,7 @@ int main(int argc, char ** argv) { MHD_NO, MHD_NO); + MHD_add_response_header(error401Response, MHD_HTTP_HEADER_CONTENT_TYPE, "text/html"); MHD_add_response_header(error403Response, MHD_HTTP_HEADER_CONTENT_TYPE, "text/html"); MHD_add_response_header(error404Response, MHD_HTTP_HEADER_CONTENT_TYPE, "text/html"); MHD_add_response_header(error500Response, MHD_HTTP_HEADER_CONTENT_TYPE, "text/html"); @@ -810,6 +871,7 @@ int main(int argc, char ** argv) { MHD_stop_daemon(d6); + MHD_destroy_response(error401Response); MHD_destroy_response(error403Response); MHD_destroy_response(error404Response); MHD_destroy_response(error500Response);