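/*
 * NOTE: the following lines are code-hosting page residue captured together
 * with this source file; they are kept for reference and are not program code.
 *
 * Files
 * TLS/tls_server.cpp
 * 2025-10-09 10:50:30 +08:00
 *
 * 472 lines
 * 15 KiB
 * C++
 * Raw Permalink Blame History
 *
 * This file contains ambiguous Unicode characters
 *
 * This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
 */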

#include <atomic>
#include <cerrno>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <signal.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include <unistd.h>

#include <openssl/err.h>
#include <openssl/ssl.h>
// ---- Defaults used when the corresponding command-line option is not given ----
const int DEFAULT_PORT{7271};
const char* SERVER_CERT_FILE{"server.crt"};
const char* SERVER_KEY_FILE{"server.key"};
const char* CA_CERT_FILE{"ca.crt"};

// Cleared by signal_handler(); the accept loop in main() runs while it is non-zero.
volatile sig_atomic_t server_running{1};

// Connection counters reported by print_connection_stats().
std::atomic<int> connection_count{0};
std::atomic<int> successful_connections{0};
std::atomic<int> failed_connections{0};

// Serialises console output in log_message() and print_connection_stats().
std::mutex log_mutex;
/**
 * @brief SIGINT/SIGTERM handler: announce the shutdown and clear server_running.
 *
 * Only async-signal-safe operations are used. The previous version wrote through
 * std::cout, which is not async-signal-safe: if the signal arrives while the main
 * thread is inside an iostream call, the stream can deadlock or be corrupted.
 * The emitted bytes are the same as before ("\n收到信号 <sig>,正在关闭服务器...\n");
 * the signal number is formatted by hand because iostreams/snprintf are not
 * allowed in a handler.
 *
 * @param sig the delivered signal number
 */
void signal_handler(int sig) {
    // write() may change errno; restore it so the interrupted code sees its own value.
    const int saved_errno = errno;

    static const char prefix[] = "\n收到信号 ";
    static const char suffix[] = ",正在关闭服务器...\n";

    // Render sig in decimal, filling the buffer from the end.
    char digits[16];
    std::size_t pos = sizeof(digits);
    unsigned int value = sig < 0 ? 0u - static_cast<unsigned int>(sig)
                                 : static_cast<unsigned int>(sig);
    do {
        digits[--pos] = static_cast<char>('0' + value % 10);
        value /= 10;
    } while (value != 0);
    if (sig < 0) {
        digits[--pos] = '-';
    }

    // Best-effort console output: a failed or short write is deliberately ignored.
    ssize_t ignored = write(STDOUT_FILENO, prefix, sizeof(prefix) - 1);
    ignored = write(STDOUT_FILENO, digits + pos, sizeof(digits) - pos);
    ignored = write(STDOUT_FILENO, suffix, sizeof(suffix) - 1);
    (void)ignored;

    server_running = 0;
    errno = saved_errno;
}
/**
 * @brief Return the current local time as "YYYY-MM-DD HH:MM:SS.mmm".
 *
 * Uses localtime_r() rather than std::localtime(): the latter returns a pointer
 * to a shared static buffer and is not thread-safe. If localtime_r() fails, the
 * zero-initialised std::tm is formatted instead of dereferencing a null pointer.
 *
 * @return std::string the formatted timestamp with millisecond precision
 */
std::string get_current_time() {
    const auto now = std::chrono::system_clock::now();
    const std::time_t seconds = std::chrono::system_clock::to_time_t(now);
    const auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(
                            now.time_since_epoch()) % 1000;

    std::tm local_tm{};
    localtime_r(&seconds, &local_tm);

    std::ostringstream ss;
    ss << std::put_time(&local_tm, "%Y-%m-%d %H:%M:%S")
       << '.' << std::setfill('0') << std::setw(3) << millis.count();
    return ss.str();
}
/**
* @brief 线程安全的日志输出
* @param message 日志消息
*/
void log_message(const std::string& message) {
std::lock_guard<std::mutex> lock(log_mutex);
std::cout << "[" << get_current_time() << "] " << message << std::endl;
}
/**
 * @brief Print the connection counters and the success rate to std::cout.
 *
 * The three counters are read once, so the printed total, success count,
 * failure count and rate are mutually consistent. The stream's floating-point
 * format state is restored afterwards: std::fixed and std::setprecision are
 * sticky and previously leaked into every later value printed on std::cout.
 */
void print_connection_stats() {
    const int total = connection_count.load();
    const int succeeded = successful_connections.load();
    const int failed = failed_connections.load();

    std::lock_guard<std::mutex> lock(log_mutex);
    std::cout << "\n=== 连接统计信息 ===" << std::endl;
    std::cout << "总连接次数: " << total << std::endl;
    std::cout << "成功连接次数: " << succeeded << std::endl;
    std::cout << "失败连接次数: " << failed << std::endl;
    if (total > 0) {
        const double success_rate = static_cast<double>(succeeded) / total * 100;
        const std::ios_base::fmtflags saved_flags = std::cout.flags();
        const std::streamsize saved_precision = std::cout.precision();
        std::cout << "成功率: " << std::fixed << std::setprecision(2) << success_rate << "%" << std::endl;
        std::cout.flags(saved_flags);
        std::cout.precision(saved_precision);
    }
    std::cout << "===================" << std::endl;
}
/**
 * @brief Print the command-line help text, including the default values.
 * @param program_name name shown in the usage and example lines (normally argv[0])
 */
void print_usage(const char* program_name) {
    std::ostream& out = std::cout;

    out << "用法: " << program_name << " [选项]" << std::endl;

    out << "选项:" << std::endl
        << " -p <port> 服务器端口 (默认: " << DEFAULT_PORT << ")" << std::endl
        << " -c <cert_file> 服务器证书文件 (默认: " << SERVER_CERT_FILE << ")" << std::endl
        << " -k <key_file> 服务器私钥文件 (默认: " << SERVER_KEY_FILE << ")" << std::endl
        << " -a <ca_file> CA证书文件 (默认: " << CA_CERT_FILE << ")" << std::endl
        << " -h 显示此帮助信息" << std::endl
        << std::endl;

    // Argument strings for the example invocations, appended after the program name.
    static const char* const example_args[] = {
        " -p 8443",
        " -p 443 -c my_server.crt -k my_server.key",
        " -p 8443 -c server.crt -k server.key -a ca.crt",
    };
    out << "示例:" << std::endl;
    for (const char* args : example_args) {
        out << " " << program_name << args << std::endl;
    }
}
/**
 * @brief Parse the command line with getopt ("p:c:k:a:h").
 *
 * Every output is first reset to its default. Then each option that is present
 * overrides the corresponding output.
 *
 * @param argc      argument count from main()
 * @param argv      argument vector from main()
 * @param port      out: listening port, validated to be in 1..65535
 * @param cert_file out: server certificate path
 * @param key_file  out: server private key path
 * @param ca_file   out: CA certificate path used to verify clients
 * @return bool true on success. Returns false when the port is invalid (after
 *         printing an error), when -h is given, or when an unknown option is
 *         seen (the last two print the usage text).
 */
bool parse_arguments(int argc, char* argv[], int& port, std::string& cert_file,
                     std::string& key_file, std::string& ca_file) {
    port = DEFAULT_PORT;
    cert_file = SERVER_CERT_FILE;
    key_file = SERVER_KEY_FILE;
    ca_file = CA_CERT_FILE;

    int opt;
    while ((opt = getopt(argc, argv, "p:c:k:a:h")) != -1) {
        switch (opt) {
        case 'p': {
            // strtol plus an end-pointer check rejects "", "abc", "80abc" and
            // out-of-range values. std::atoi silently accepted "80abc" and has
            // undefined behaviour on overflow.
            errno = 0;
            char* end = nullptr;
            const long value = std::strtol(optarg, &end, 10);
            if (errno != 0 || end == optarg || *end != '\0' || value <= 0 || value > 65535) {
                std::cerr << "错误: 端口号必须在1-65535之间" << std::endl;
                return false;
            }
            port = static_cast<int>(value);
            break;
        }
        case 'c':
            cert_file = optarg;
            break;
        case 'k':
            key_file = optarg;
            break;
        case 'a':
            ca_file = optarg;
            break;
        case 'h':      // explicit help request
        default:       // unknown option or missing argument (getopt returned '?')
            print_usage(argv[0]);
            return false;
        }
    }
    return true;
}
void print_openssl_errors() {
BIO* bio = BIO_new_fp(stderr, BIO_NOCLOSE);
ERR_print_errors(bio);
BIO_free(bio);
}
/**
* @brief 分析SSL错误并返回详细错误信息
* @param ssl SSL连接对象
* @param ret SSL函数返回值
* @return std::string 详细错误信息
*/
std::string analyze_ssl_error(SSL* ssl, int ret) {
std::stringstream error_info;
int err = SSL_get_error(ssl, ret);
error_info << "SSL错误代码: " << err << " (";
switch (err) {
case SSL_ERROR_NONE:
error_info << "SSL_ERROR_NONE";
break;
case SSL_ERROR_SSL:
error_info << "SSL_ERROR_SSL - SSL协议错误";
break;
case SSL_ERROR_WANT_READ:
error_info << "SSL_ERROR_WANT_READ - 需要读取更多数据";
break;
case SSL_ERROR_WANT_WRITE:
error_info << "SSL_ERROR_WANT_WRITE - 需要写入更多数据";
break;
case SSL_ERROR_WANT_X509_LOOKUP:
error_info << "SSL_ERROR_WANT_X509_LOOKUP - 需要X509查找";
break;
case SSL_ERROR_SYSCALL:
error_info << "SSL_ERROR_SYSCALL - 系统调用错误";
break;
case SSL_ERROR_ZERO_RETURN:
error_info << "SSL_ERROR_ZERO_RETURN - 连接被关闭";
break;
case SSL_ERROR_WANT_CONNECT:
error_info << "SSL_ERROR_WANT_CONNECT - 需要连接";
break;
case SSL_ERROR_WANT_ACCEPT:
error_info << "SSL_ERROR_WANT_ACCEPT - 需要接受连接";
break;
default:
error_info << "未知SSL错误";
break;
}
error_info << ")";
// 如果是系统调用错误,获取系统错误信息
if (err == SSL_ERROR_SYSCALL) {
int sys_err = errno;
if (sys_err != 0) {
error_info << " | 系统错误: " << sys_err << " (" << strerror(sys_err) << ")";
}
}
// 获取SSL状态信息
int state = SSL_get_state(ssl);
error_info << " | SSL状态: " << SSL_state_string_long(ssl);
return error_info.str();
}
/**
 * @brief Initialise OpenSSL and load its human-readable error strings.
 *
 * Before 1.1.0 this needs the legacy calls. From 1.1.0 on they are deprecated
 * macros (unavailable when building with a newer OPENSSL_API_COMPAT), so
 * OPENSSL_init_ssl() is called directly. This mirrors the version guard in
 * cleanup_openssl(). An initialisation failure terminates the process, as
 * create_context() does for SSL_CTX_new() failures.
 */
void init_openssl() {
#if OPENSSL_VERSION_NUMBER < 0x10100000L
    SSL_load_error_strings();
    OpenSSL_add_ssl_algorithms();
#else
    if (OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS,
                         nullptr) != 1) {
        std::cerr << "无法初始化OpenSSL" << std::endl;
        print_openssl_errors();
        exit(EXIT_FAILURE);
    }
#endif
}
/**
 * @brief Release global OpenSSL state at shutdown.
 *
 * Only releases older than 1.1.0 need explicit teardown, via EVP_cleanup().
 * For newer releases the body compiles to nothing.
 */
void cleanup_openssl() {
#if !(OPENSSL_VERSION_NUMBER >= 0x10100000L)
    EVP_cleanup();
#endif
}
/**
 * @brief Create the server-side SSL_CTX, restricted to TLS 1.2 through TLS 1.3.
 *
 * Terminates the process, after printing the OpenSSL error queue, if the
 * context cannot be created or the protocol range cannot be applied.
 *
 * @return SSL_CTX* a new context owned by the caller (free with SSL_CTX_free)
 */
SSL_CTX* create_context() {
    const SSL_METHOD* method = TLS_server_method();
    SSL_CTX* ctx = SSL_CTX_new(method);
    if (!ctx) {
        std::cerr << "无法创建SSL上下文" << std::endl;
        print_openssl_errors();
        exit(EXIT_FAILURE);
    }
    // Both setters return 1 on success. Previously a failure was ignored, which
    // would leave the library's default range (possibly including older
    // protocol versions) in effect.
    if (SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION) != 1 ||
        SSL_CTX_set_max_proto_version(ctx, TLS1_3_VERSION) != 1) {
        std::cerr << "无法设置TLS协议版本范围" << std::endl;
        print_openssl_errors();
        SSL_CTX_free(ctx);
        exit(EXIT_FAILURE);
    }
    return ctx;
}
int verify_callback(int preverify_ok, X509_STORE_CTX* ctx) {
if (!preverify_ok) {
int err = X509_STORE_CTX_get_error(ctx);
if (err == X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT) {
return 1;
}
}
return preverify_ok;
}
/**
 * @brief Load the server certificate and key plus the client CA, and enable
 *        mandatory client-certificate verification.
 *
 * Every failure is logged, the OpenSSL error queue is printed, and the process
 * exits with EXIT_FAILURE.
 *
 * @param ctx       context returned by create_context()
 * @param cert_file PEM server certificate
 * @param key_file  PEM private key matching cert_file
 * @param ca_file   PEM CA bundle used to verify client certificates
 */
void configure_context(SSL_CTX* ctx, const std::string& cert_file, const std::string& key_file, const std::string& ca_file) {
    log_message("加载服务器证书: " + cert_file);
    if (SSL_CTX_use_certificate_file(ctx, cert_file.c_str(), SSL_FILETYPE_PEM) <= 0) {
        log_message("错误: 无法加载服务器证书 " + cert_file);
        print_openssl_errors();
        exit(EXIT_FAILURE);
    }
    log_message("服务器证书加载成功");

    log_message("加载服务器私钥: " + key_file);
    if (SSL_CTX_use_PrivateKey_file(ctx, key_file.c_str(), SSL_FILETYPE_PEM) <= 0) {
        log_message("错误: 无法加载服务器私钥 " + key_file);
        print_openssl_errors();
        exit(EXIT_FAILURE);
    }
    log_message("服务器私钥加载成功");

    log_message("验证证书和私钥匹配...");
    if (!SSL_CTX_check_private_key(ctx)) {
        log_message("错误: 证书和私钥不匹配");
        print_openssl_errors();  // consistent with the other failure paths
        exit(EXIT_FAILURE);
    }
    log_message("证书和私钥匹配验证成功");

    log_message("加载CA证书: " + ca_file);
    if (SSL_CTX_load_verify_locations(ctx, ca_file.c_str(), nullptr) <= 0) {
        log_message("错误: 无法加载CA证书 " + ca_file);
        print_openssl_errors();
        exit(EXIT_FAILURE);
    }
    log_message("CA证书加载成功");

    // Require a client certificate and accept chains of at most 4 intermediates.
    SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback);
    SSL_CTX_set_verify_depth(ctx, 4);

    // When a server verifies client certificates, a session id context must be
    // set. Otherwise OpenSSL fails the whole handshake of any client that tries
    // to resume a cached session ("session id context uninitialized").
    static const unsigned char session_id_context[] = "tls_server";
    if (SSL_CTX_set_session_id_context(ctx, session_id_context,
                                       sizeof(session_id_context) - 1) != 1) {
        log_message("错误: 无法设置会话ID上下文");
        print_openssl_errors();
        exit(EXIT_FAILURE);
    }

    // Disable legacy protocols and TLS-level compression.
    SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION);

    // Restrict TLS <= 1.2 to a single suite; TLS 1.3 suites are configured
    // separately and are not affected. NOTE: an ECDHE-RSA suite needs an RSA
    // server key, so TLS 1.2 clients cannot connect if key_file is ECDSA.
    // Treat a rejected cipher string as fatal instead of ignoring it.
    if (SSL_CTX_set_cipher_list(ctx, "ECDHE-RSA-AES128-GCM-SHA256") != 1) {
        log_message("错误: 无法设置密码套件列表");
        print_openssl_errors();
        exit(EXIT_FAILURE);
    }
    log_message("SSL上下文配置完成");
}
/**
 * @brief Send close_notify, then release the SSL object and the socket.
 *
 * The SSL object and the fd are released independently. The previous early
 * return leaked the SSL object when fd < 0, and leaked the fd when ssl was null.
 *
 * @param ssl connection to shut down and free (may be nullptr)
 * @param fd  socket descriptor to close (ignored if negative)
 */
void graceful_ssl_close(SSL* ssl, int fd) {
    if (ssl) {
        // One-shot close_notify; we do not wait for the peer's reply, so the
        // return value is intentionally ignored.
        SSL_shutdown(ssl);
        SSL_free(ssl);
    }
    if (fd >= 0) {
        close(fd);
    }
    // SSL_shutdown() on a broken connection can leave entries in this thread's
    // error queue. Clear them so they cannot skew SSL_get_error() for the next
    // connection.
    ERR_clear_error();
}
/**
 * @brief Serve one client after a successful TLS handshake, then close it.
 *
 * Steps:
 * - Log the negotiated protocol and cipher.
 * - Log and check the client certificate.
 * - Read one request (at most 4096 bytes, 60 s receive timeout).
 * - Answer with a fixed plain-text HTTP/1.1 response.
 *
 * Takes ownership of @p ssl and @p client_fd; both are released through
 * graceful_ssl_close() on every path. Each call increments exactly one of
 * successful_connections / failed_connections.
 *
 * @param ssl         established TLS connection
 * @param client_fd   socket underlying @p ssl
 * @param client_ip   printable peer address (for logging)
 * @param client_port peer port in host byte order (for logging)
 */
void handle_connection(SSL* ssl, int client_fd, const char* client_ip, uint16_t client_port) {
    // Log the reason, optionally dump the OpenSSL error queue, count the failure
    // and release the connection.
    auto fail = [&](const std::string& reason, bool dump_openssl_errors) {
        log_message(reason);
        if (dump_openssl_errors) {
            print_openssl_errors();
        }
        failed_connections++;
        graceful_ssl_close(ssl, client_fd);
    };

    log_message("处理连接 " + std::string(client_ip) + ":" + std::to_string(client_port));
    log_message("协议: " + std::string(SSL_get_version(ssl)));
    log_message("密码套件: " + std::string(SSL_get_cipher(ssl)));

    // Verify the client certificate.
    X509* client_cert = SSL_get_peer_certificate(ssl);
    if (!client_cert) {
        fail("错误: 客户端未提供证书", false);
        return;
    }
    char* subject = X509_NAME_oneline(X509_get_subject_name(client_cert), nullptr, 0);
    char* issuer = X509_NAME_oneline(X509_get_issuer_name(client_cert), nullptr, 0);
    log_message("客户端证书主题: " + std::string(subject ? subject : "未知"));
    log_message("客户端证书颁发者: " + std::string(issuer ? issuer : "未知"));
    OPENSSL_free(subject);  // OPENSSL_free(nullptr) is a no-op
    OPENSSL_free(issuer);
    X509_free(client_cert);

    // Defence in depth: a verify callback can override OpenSSL's verdict during
    // the handshake, but the recorded verification result still shows the real
    // outcome. Only fully verified clients are served.
    const long verify_result = SSL_get_verify_result(ssl);
    if (verify_result != X509_V_OK) {
        fail("错误: 客户端证书验证失败 - " +
             std::string(X509_verify_cert_error_string(verify_result)), false);
        return;
    }

    // Send small responses immediately instead of letting Nagle delay them.
    int optval = 1;
    setsockopt(client_fd, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval));
    // Give up on a client that sends nothing for 60 seconds.
    struct timeval timeout = {60, 0};
    setsockopt(client_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));

    log_message("等待客户端数据...");
    char buffer[4096];
    const int bytes = SSL_read(ssl, buffer, sizeof(buffer));
    if (bytes > 0) {
        log_message("收到数据: " + std::to_string(bytes) + "字节");
        static const char response[] =
            "HTTP/1.1 200 OK\r\n"
            "Content-Type: text/plain\r\n"
            "Connection: close\r\n"
            "\r\n"
            "Hello from TLS server!";
        const int written = SSL_write(ssl, response, static_cast<int>(sizeof(response) - 1));
        if (written <= 0) {
            // Previously the write result was ignored and the connection counted as a success.
            fail("SSL写入错误: " + analyze_ssl_error(ssl, written), true);
            return;
        }
        log_message("已发送响应");
        successful_connections++;
    } else if (bytes == 0) {
        // The peer closed without sending a request. That is not a success, so
        // count it as a failure to keep success + failure equal to the served total.
        fail("客户端关闭连接", false);
        return;
    } else {
        fail("SSL读取错误: " + analyze_ssl_error(ssl, bytes), true);
        return;
    }
    graceful_ssl_close(ssl, client_fd);
}
int main(int argc, char* argv[]) {
// 解析命令行参数
int port;
std::string cert_file, key_file, ca_file;
if (!parse_arguments(argc, argv, port, cert_file, key_file, ca_file)) {
return 1;
}
signal(SIGINT, signal_handler);
signal(SIGTERM, signal_handler);
log_message("=== TLS服务器启动 ===");
log_message("端口: " + std::to_string(port));
log_message("证书文件: " + cert_file);
log_message("私钥文件: " + key_file);
log_message("CA证书文件: " + ca_file);
log_message("==================");
init_openssl();
SSL_CTX* ctx = create_context();
configure_context(ctx, cert_file, key_file, ca_file);
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
log_message("错误: 创建套接字失败 - " + std::string(strerror(errno)));
exit(EXIT_FAILURE);
}
int optval = 1;
setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
sockaddr_in addr{};
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
addr.sin_addr.s_addr = INADDR_ANY;
if (bind(sockfd, (sockaddr*)&addr, sizeof(addr)) < 0) {
log_message("错误: 绑定端口失败 - " + std::string(strerror(errno)));
exit(EXIT_FAILURE);
}
if (listen(sockfd, 10) < 0) {
log_message("错误: 监听失败 - " + std::string(strerror(errno)));
exit(EXIT_FAILURE);
}
log_message("服务器启动,监听端口 " + std::to_string(port));
log_message("等待客户端连接...");
while (server_running) {
sockaddr_in client_addr{};
socklen_t client_len = sizeof(client_addr);
int client_fd = accept(sockfd, (sockaddr*)&client_addr, &client_len);
if (client_fd < 0) {
if (server_running) {
log_message("接受连接失败: " + std::string(strerror(errno)));
}
continue;
}
// 增加连接计数
connection_count++;
int current_conn = connection_count.load();
char client_ip[INET_ADDRSTRLEN];
inet_ntop(AF_INET, &client_addr.sin_addr, client_ip, INET_ADDRSTRLEN);
uint16_t client_port = ntohs(client_addr.sin_port);
log_message("================");
log_message("新连接 #" + std::to_string(current_conn) + " 来自 " +
std::string(client_ip) + ":" + std::to_string(client_port));
SSL* ssl = SSL_new(ctx);
SSL_set_fd(ssl, client_fd);
log_message("开始SSL握手...");
int ssl_ret = SSL_accept(ssl);
if (ssl_ret <= 0) {
std::string error_msg = "SSL握手失败 #" + std::to_string(current_conn) + ": " +
analyze_ssl_error(ssl, ssl_ret);
log_message(error_msg);
// 打印详细的OpenSSL错误
log_message("OpenSSL错误详情:");
print_openssl_errors();
failed_connections++;
SSL_free(ssl);
close(client_fd);
continue;
}
log_message("SSL握手成功 #" + std::to_string(current_conn));
handle_connection(ssl, client_fd, client_ip, client_port);
// 每10个连接显示一次统计信息
if (current_conn % 10 == 0) {
print_connection_stats();
}
}
log_message("服务器关闭中...");
print_connection_stats();
close(sockfd);
SSL_CTX_free(ctx);
cleanup_openssl();
return 0;
}