#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include const int DEFAULT_PORT = 7271; const char* SERVER_CERT_FILE = "server.crt"; const char* SERVER_KEY_FILE = "server.key"; const char* CA_CERT_FILE = "ca.crt"; volatile sig_atomic_t server_running = 1; std::atomic connection_count(0); std::atomic successful_connections(0); std::atomic failed_connections(0); std::mutex log_mutex; void signal_handler(int sig) { std::cout << "\n收到信号 " << sig << ",正在关闭服务器..." << std::endl; server_running = 0; } /** * @brief 获取当前时间戳字符串 * @return std::string 格式化的时间戳 */ std::string get_current_time() { auto now = std::chrono::system_clock::now(); auto time_t = std::chrono::system_clock::to_time_t(now); auto ms = std::chrono::duration_cast( now.time_since_epoch()) % 1000; std::stringstream ss; ss << std::put_time(std::localtime(&time_t), "%Y-%m-%d %H:%M:%S"); ss << '.' << std::setfill('0') << std::setw(3) << ms.count(); return ss.str(); } /** * @brief 线程安全的日志输出 * @param message 日志消息 */ void log_message(const std::string& message) { std::lock_guard lock(log_mutex); std::cout << "[" << get_current_time() << "] " << message << std::endl; } /** * @brief 打印连接统计信息 */ void print_connection_stats() { std::lock_guard lock(log_mutex); std::cout << "\n=== 连接统计信息 ===" << std::endl; std::cout << "总连接次数: " << connection_count.load() << std::endl; std::cout << "成功连接次数: " << successful_connections.load() << std::endl; std::cout << "失败连接次数: " << failed_connections.load() << std::endl; if (connection_count.load() > 0) { double success_rate = (double)successful_connections.load() / connection_count.load() * 100; std::cout << "成功率: " << std::fixed << std::setprecision(2) << success_rate << "%" << std::endl; } std::cout << "===================" << std::endl; } /** * @brief 打印使用说明 */ void print_usage(const char* program_name) { std::cout << "用法: " << program_name << " [选项]" << std::endl; std::cout << "选项:" << std::endl; std::cout << " -p 服务器端口 (默认: " << DEFAULT_PORT << ")" << std::endl; std::cout << " -c 服务器证书文件 (默认: " << SERVER_CERT_FILE << ")" << std::endl; std::cout << " -k 服务器私钥文件 (默认: " << SERVER_KEY_FILE << ")" << std::endl; std::cout << " -a CA证书文件 (默认: " << CA_CERT_FILE << ")" << std::endl; std::cout << " -h 显示此帮助信息" << std::endl; std::cout << std::endl; std::cout << "示例:" << std::endl; std::cout << " " << program_name << " -p 8443" << std::endl; std::cout << " " << program_name << " -p 443 -c my_server.crt -k my_server.key" << std::endl; std::cout << " " << program_name << " -p 8443 -c server.crt -k server.key -a ca.crt" << std::endl; } /** * @brief 解析命令行参数 * @param argc 参数个数 * @param argv 参数数组 * @param port 输出端口 * @param cert_file 输出证书文件 * @param key_file 输出私钥文件 * @param ca_file 输出CA证书文件 * @return bool 解析成功返回true,失败返回false */ bool parse_arguments(int argc, char* argv[], int& port, std::string& cert_file, std::string& key_file, std::string& ca_file) { port = DEFAULT_PORT; cert_file = SERVER_CERT_FILE; key_file = SERVER_KEY_FILE; ca_file = CA_CERT_FILE; int opt; while ((opt = getopt(argc, argv, "p:c:k:a:h")) != -1) { switch (opt) { case 'p': port = std::atoi(optarg); if (port <= 0 || port > 65535) { std::cerr << "错误: 端口号必须在1-65535之间" << std::endl; return false; } break; case 'c': cert_file = optarg; break; case 'k': key_file = optarg; break; case 'a': ca_file = optarg; break; case 'h': print_usage(argv[0]); return false; default: print_usage(argv[0]); return false; } } return true; } void print_openssl_errors() { BIO* bio = BIO_new_fp(stderr, BIO_NOCLOSE); ERR_print_errors(bio); BIO_free(bio); } /** * @brief 分析SSL错误并返回详细错误信息 * @param ssl SSL连接对象 * @param ret SSL函数返回值 * @return std::string 详细错误信息 */ std::string analyze_ssl_error(SSL* ssl, int ret) { std::stringstream error_info; int err = SSL_get_error(ssl, ret); error_info << "SSL错误代码: " << err << " ("; switch (err) { case SSL_ERROR_NONE: error_info << "SSL_ERROR_NONE"; break; case SSL_ERROR_SSL: error_info << "SSL_ERROR_SSL - SSL协议错误"; break; case SSL_ERROR_WANT_READ: error_info << "SSL_ERROR_WANT_READ - 需要读取更多数据"; break; case SSL_ERROR_WANT_WRITE: error_info << "SSL_ERROR_WANT_WRITE - 需要写入更多数据"; break; case SSL_ERROR_WANT_X509_LOOKUP: error_info << "SSL_ERROR_WANT_X509_LOOKUP - 需要X509查找"; break; case SSL_ERROR_SYSCALL: error_info << "SSL_ERROR_SYSCALL - 系统调用错误"; break; case SSL_ERROR_ZERO_RETURN: error_info << "SSL_ERROR_ZERO_RETURN - 连接被关闭"; break; case SSL_ERROR_WANT_CONNECT: error_info << "SSL_ERROR_WANT_CONNECT - 需要连接"; break; case SSL_ERROR_WANT_ACCEPT: error_info << "SSL_ERROR_WANT_ACCEPT - 需要接受连接"; break; default: error_info << "未知SSL错误"; break; } error_info << ")"; // 如果是系统调用错误,获取系统错误信息 if (err == SSL_ERROR_SYSCALL) { int sys_err = errno; if (sys_err != 0) { error_info << " | 系统错误: " << sys_err << " (" << strerror(sys_err) << ")"; } } // 获取SSL状态信息 int state = SSL_get_state(ssl); error_info << " | SSL状态: " << SSL_state_string_long(ssl); return error_info.str(); } void init_openssl() { SSL_load_error_strings(); OpenSSL_add_ssl_algorithms(); } void cleanup_openssl() { #if OPENSSL_VERSION_NUMBER < 0x10100000L EVP_cleanup(); #endif } SSL_CTX* create_context() { const SSL_METHOD* method = TLS_server_method(); SSL_CTX* ctx = SSL_CTX_new(method); if (!ctx) { std::cerr << "无法创建SSL上下文" << std::endl; print_openssl_errors(); exit(EXIT_FAILURE); } SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION); return ctx; } int verify_callback(int preverify_ok, X509_STORE_CTX* ctx) { if (!preverify_ok) { int err = X509_STORE_CTX_get_error(ctx); if (err == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT) { return 1; } } return preverify_ok; } void configure_context(SSL_CTX* ctx, const std::string& cert_file, const std::string& key_file, const std::string& ca_file) { log_message("加载服务器证书: " + cert_file); if (SSL_CTX_use_certificate_file(ctx, cert_file.c_str(), SSL_FILETYPE_PEM) <= 0) { log_message("错误: 无法加载服务器证书 " + cert_file); print_openssl_errors(); exit(EXIT_FAILURE); } log_message("服务器证书加载成功"); log_message("加载服务器私钥: " + key_file); if (SSL_CTX_use_PrivateKey_file(ctx, key_file.c_str(), SSL_FILETYPE_PEM) <= 0) { log_message("错误: 无法加载服务器私钥 " + key_file); print_openssl_errors(); exit(EXIT_FAILURE); } log_message("服务器私钥加载成功"); log_message("验证证书和私钥匹配..."); if (!SSL_CTX_check_private_key(ctx)) { log_message("错误: 证书和私钥不匹配"); exit(EXIT_FAILURE); } log_message("证书和私钥匹配验证成功"); log_message("加载CA证书: " + ca_file); if (SSL_CTX_load_verify_locations(ctx, ca_file.c_str(), nullptr) <= 0) { log_message("错误: 无法加载CA证书 " + ca_file); print_openssl_errors(); exit(EXIT_FAILURE); } log_message("CA证书加载成功"); SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback); SSL_CTX_set_verify_depth(ctx, 4); // 关键优化:简化SSL选项 SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION); SSL_CTX_set_cipher_list(ctx, "ECDHE-RSA-AES128-GCM-SHA256"); log_message("SSL上下文配置完成"); } void graceful_ssl_close(SSL* ssl, int fd) { if (!ssl || fd < 0) return; // 简化的关闭逻辑 SSL_shutdown(ssl); SSL_free(ssl); close(fd); } void handle_connection(SSL* ssl, int client_fd, const char* client_ip, uint16_t client_port) { log_message("处理连接 " + std::string(client_ip) + ":" + std::to_string(client_port)); log_message("协议: " + std::string(SSL_get_version(ssl))); log_message("密码套件: " + std::string(SSL_get_cipher(ssl))); // 验证客户端证书 X509* client_cert = SSL_get_peer_certificate(ssl); if (!client_cert) { log_message("错误: 客户端未提供证书"); failed_connections++; graceful_ssl_close(ssl, client_fd); return; } // 获取证书信息 char* subject = X509_NAME_oneline(X509_get_subject_name(client_cert), 0, 0); char* issuer = X509_NAME_oneline(X509_get_issuer_name(client_cert), 0, 0); log_message("客户端证书主题: " + std::string(subject ? subject : "未知")); log_message("客户端证书颁发者: " + std::string(issuer ? issuer : "未知")); if (subject) OPENSSL_free(subject); if (issuer) OPENSSL_free(issuer); X509_free(client_cert); // 设置TCP_NODELAY int optval = 1; setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval)); // 设置60秒读取超时 struct timeval timeout = {60, 0}; setsockopt(client_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)); log_message("等待客户端数据..."); // 读取客户端数据 char buffer[4096]; int bytes = SSL_read(ssl, buffer, sizeof(buffer)); if (bytes > 0) { log_message("收到数据: " + std::to_string(bytes) + "字节"); // 发送响应 const char* response = "HTTP/1.1 200 OK\r\n" "Content-Type: text/plain\r\n" "Connection: close\r\n" "\r\n" "Hello from TLS server!"; SSL_write(ssl, response, strlen(response)); log_message("已发送响应"); successful_connections++; } else if (bytes == 0) { log_message("客户端关闭连接"); } else { int ssl_err = SSL_get_error(ssl, bytes); log_message("SSL读取错误: " + analyze_ssl_error(ssl, bytes)); } graceful_ssl_close(ssl, client_fd); } int main(int argc, char* argv[]) { // 解析命令行参数 int port; std::string cert_file, key_file, ca_file; if (!parse_arguments(argc, argv, port, cert_file, key_file, ca_file)) { return 1; } signal(SIGINT, signal_handler); signal(SIGTERM, signal_handler); log_message("=== TLS服务器启动 ==="); log_message("端口: " + std::to_string(port)); log_message("证书文件: " + cert_file); log_message("私钥文件: " + key_file); log_message("CA证书文件: " + ca_file); log_message("=================="); init_openssl(); SSL_CTX* ctx = create_context(); configure_context(ctx, cert_file, key_file, ca_file); int sockfd = socket(AF_INET, SOCK_STREAM, 0); if (sockfd < 0) { log_message("错误: 创建套接字失败 - " + std::string(strerror(errno))); exit(EXIT_FAILURE); } int optval = 1; setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)); sockaddr_in addr{}; addr.sin_family = AF_INET; addr.sin_port = htons(port); addr.sin_addr.s_addr = INADDR_ANY; if (bind(sockfd, (sockaddr*)&addr, sizeof(addr)) < 0) { log_message("错误: 绑定端口失败 - " + std::string(strerror(errno))); exit(EXIT_FAILURE); } if (listen(sockfd, 10) < 0) { log_message("错误: 监听失败 - " + std::string(strerror(errno))); exit(EXIT_FAILURE); } log_message("服务器启动,监听端口 " + std::to_string(port)); log_message("等待客户端连接..."); while (server_running) { sockaddr_in client_addr{}; socklen_t client_len = sizeof(client_addr); int client_fd = accept(sockfd, (sockaddr*)&client_addr, &client_len); if (client_fd < 0) { if (server_running) { log_message("接受连接失败: " + std::string(strerror(errno))); } continue; } // 增加连接计数 connection_count++; int current_conn = connection_count.load(); char client_ip[INET_ADDRSTRLEN]; inet_ntop(AF_INET, &client_addr.sin_addr, client_ip, INET_ADDRSTRLEN); uint16_t client_port = ntohs(client_addr.sin_port); log_message("================"); log_message("新连接 #" + std::to_string(current_conn) + " 来自 " + std::string(client_ip) + ":" + std::to_string(client_port)); SSL* ssl = SSL_new(ctx); SSL_set_fd(ssl, client_fd); log_message("开始SSL握手..."); int ssl_ret = SSL_accept(ssl); if (ssl_ret <= 0) { std::string error_msg = "SSL握手失败 #" + std::to_string(current_conn) + ": " + analyze_ssl_error(ssl, ssl_ret); log_message(error_msg); // 打印详细的OpenSSL错误 log_message("OpenSSL错误详情:"); print_openssl_errors(); failed_connections++; SSL_free(ssl); close(client_fd); continue; } log_message("SSL握手成功 #" + std::to_string(current_conn)); handle_connection(ssl, client_fd, client_ip, client_port); // 每10个连接显示一次统计信息 if (current_conn % 10 == 0) { print_connection_stats(); } } log_message("服务器关闭中..."); print_connection_stats(); close(sockfd); SSL_CTX_free(ctx); cleanup_openssl(); return 0; }