Files
DisplayFlow/demo/windows_sender/TcpServer.cpp
2025-12-22 14:49:47 +08:00

419 lines
14 KiB
C++

#include "TcpServer.h"
#include "FileTransferProtocol.h"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <mutex>
#include <string>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>

#include <ws2tcpip.h>
#pragma comment(lib, "ws2_32.lib")
namespace fs = std::filesystem;
// Construction does no work: Winsock and the listening socket are only set up by Start().
TcpServer::TcpServer() = default;

// Destruction delegates to Stop(), which closes the sockets and joins the worker threads.
TcpServer::~TcpServer() { Stop(); }
bool TcpServer::Start(int port) {
if (running_) return true;
// Initialize Winsock
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
std::cerr << "WSAStartup failed" << std::endl;
return false;
}
listenSocket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (listenSocket_ == INVALID_SOCKET) {
std::cerr << "Socket creation failed" << std::endl;
WSACleanup();
return false;
}
sockaddr_in serverAddr;
serverAddr.sin_family = AF_INET;
serverAddr.sin_addr.s_addr = INADDR_ANY;
serverAddr.sin_port = htons(port);
if (bind(listenSocket_, (sockaddr*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) {
std::cerr << "Bind failed" << std::endl;
closesocket(listenSocket_);
WSACleanup();
return false;
}
if (listen(listenSocket_, 1) == SOCKET_ERROR) {
std::cerr << "Listen failed" << std::endl;
closesocket(listenSocket_);
WSACleanup();
return false;
}
running_ = true;
acceptThread_ = std::thread(&TcpServer::AcceptLoop, this);
std::cout << "TCP Server started on port " << port << std::endl;
return true;
}
void TcpServer::Stop() {
running_ = false;
if (listenSocket_ != INVALID_SOCKET) {
closesocket(listenSocket_);
listenSocket_ = INVALID_SOCKET;
}
if (clientSocket_ != INVALID_SOCKET) {
closesocket(clientSocket_);
clientSocket_ = INVALID_SOCKET;
}
if (acceptThread_.joinable()) acceptThread_.join();
if (clientThread_.joinable()) clientThread_.join();
// Close any open file stream
if (currentFileStream_) {
currentFileStream_->close();
delete currentFileStream_;
currentFileStream_ = nullptr;
}
}
void TcpServer::AcceptLoop() {
while (running_) {
sockaddr_in clientAddr;
int clientAddrLen = sizeof(clientAddr);
SOCKET client = accept(listenSocket_, (sockaddr*)&clientAddr, &clientAddrLen);
if (client == INVALID_SOCKET) {
if (running_) std::cerr << "Accept failed" << std::endl;
break;
}
std::cout << "Client connected for File Transfer" << std::endl;
// Close previous client if any
if (clientSocket_ != INVALID_SOCKET) {
closesocket(clientSocket_);
if (clientThread_.joinable()) clientThread_.join();
}
clientSocket_ = client;
clientThread_ = std::thread(&TcpServer::ClientHandler, this, clientSocket_);
}
}
/// @brief Body of a client thread. It reads framed packets (CommonHeader followed by
///        payloadSize bytes) and dispatches them to the Handle* methods.
///
/// Runs until the peer disconnects, a read fails, a header announces an oversized
/// payload, or the server stops. Closes @p clientSocket and resets clientSocket_ on exit.
/// @param clientSocket Connected socket owned by this handler.
void TcpServer::ClientHandler(SOCKET clientSocket) {
    // Upper bound for a single payload. Data chunks written by this file's senders are
    // 64 KiB and path headers stay below 64 KiB. That the peer stays under 64 MiB is an
    // assumption — confirm against the Android sender. The cap stops a corrupt or hostile
    // header from forcing a huge allocation (std::bad_alloc escaping this thread would
    // terminate the process) or overflowing the int size given to ReceiveBytes.
    constexpr uint64_t kMaxPayloadSize = 64ull * 1024 * 1024;

    std::vector<uint8_t> payload;  // reused across packets to avoid one allocation per chunk
    while (running_) {
        CommonHeader header;
        if (!ReceiveBytes(clientSocket, &header, sizeof(header))) {
            std::cout << "Client disconnected" << std::endl;
            break;
        }
        if (static_cast<uint64_t>(header.payloadSize) > kMaxPayloadSize) {
            std::cerr << "Payload too large (" << header.payloadSize << " bytes), disconnecting" << std::endl;
            break;
        }
        payload.resize(header.payloadSize);
        if (!payload.empty() &&
            !ReceiveBytes(clientSocket, payload.data(), static_cast<int>(payload.size()))) {
            std::cout << "Client disconnected" << std::endl;
            break;
        }

        try {
            switch (static_cast<PacketType>(header.type)) {
            case PacketType::FileHeader:
                HandleFileHeader(clientSocket, payload);
                break;
            case PacketType::FileData:
                HandleFileData(payload);
                break;
            case PacketType::FileEnd:
                HandleFileEnd();
                break;
            case PacketType::FolderHeader:
                HandleFolderHeader(payload);
                break;
            case PacketType::DirEntry:
                HandleDirEntry(payload);
                break;
            case PacketType::FileHeaderV2:
                HandleFileHeaderV2(payload);
                break;
            case PacketType::FolderEnd:
                break;  // nothing to finalize on this side
            default:
                break;  // unknown packet types are ignored
            }
        } catch (const std::exception& e) {
            // An exception leaving a std::thread calls std::terminate. The whole payload has
            // already been read, so packet framing is intact: log and keep serving.
            std::cerr << "Error handling packet: " << e.what() << std::endl;
        }
    }
    closesocket(clientSocket);
    clientSocket_ = INVALID_SOCKET;
}
// Blocks until exactly `size` bytes from `sock` have been stored in `buffer`.
// Returns false as soon as recv() reports an error (< 0) or an orderly close (0).
// A non-positive `size` returns true without reading anything.
bool TcpServer::ReceiveBytes(SOCKET sock, void* buffer, int size) {
    char* cursor = static_cast<char*>(buffer);
    for (int left = size; left > 0;) {
        const int got = recv(sock, cursor, left, 0);
        if (got <= 0) {
            return false;
        }
        cursor += got;
        left -= got;
    }
    return true;
}
// Writes all `size` bytes of `buffer` to `sock`, retrying after partial sends.
// Returns false as soon as send() returns a non-positive count.
// A non-positive `size` returns true without sending anything.
bool TcpServer::SendBytes(SOCKET sock, const void* buffer, int size) {
    const char* next = static_cast<const char*>(buffer);
    const char* const end = next + (size > 0 ? size : 0);
    while (next < end) {
        const int written = send(sock, next, static_cast<int>(end - next), 0);
        if (written <= 0) return false;
        next += written;
    }
    return true;
}
/// @brief Starts receiving a single file from a FileHeader packet (payload: FileMetadata).
///
/// The file is created under <USERPROFILE>\Desktop\, or in the working directory when
/// USERPROFILE is unset. Any stream left open by an earlier header is discarded first,
/// so FileData that follows a rejected header is dropped instead of being appended to
/// the wrong file.
/// @param sock Unused.
/// @param payload Raw FileMetadata bytes as received from the peer.
void TcpServer::HandleFileHeader(SOCKET sock, const std::vector<uint8_t>& payload) {
    (void)sock;
    std::lock_guard<std::mutex> lock(fileMutex_);
    if (currentFileStream_) {
        delete currentFileStream_;
        currentFileStream_ = nullptr;
    }
    if (payload.size() < sizeof(FileMetadata)) return;
    const FileMetadata* meta = reinterpret_cast<const FileMetadata*>(payload.data());

    // fileName comes off the wire and may lack a terminating NUL. Bound the scan by the
    // array size instead of trusting std::string(const char*).
    const char* nameBegin = meta->fileName;
    const char* nameEnd = std::find(nameBegin, nameBegin + sizeof(meta->fileName), '\0');
    std::string name(nameBegin, nameEnd);

    // Keep only the last path component so the peer cannot write outside the target
    // folder ("..\\..\\x", "C:\\x" and "sub/x" all reduce to a bare file name).
    const size_t sep = name.find_last_of("/\\:");
    if (sep != std::string::npos) name.erase(0, sep + 1);
    if (name.empty() || name == "." || name == "..") {
        std::cerr << "Rejected file name from peer" << std::endl;
        return;
    }

    // Save to Desktop by default. getenv() returns nullptr when the variable is unset.
    std::string desktopPath;
    if (const char* profile = std::getenv("USERPROFILE")) {
        desktopPath = std::string(profile) + "\\Desktop\\";
    }
    currentFileName_ = desktopPath + name;
    currentFileSize_ = meta->fileSize;
    receivedBytes_ = 0;
    currentFileStream_ = new std::ofstream(currentFileName_, std::ios::binary);
    if (!currentFileStream_->is_open()) {
        std::cerr << "Cannot create file: " << currentFileName_ << std::endl;
    }
    std::cout << "Receiving file: " << currentFileName_ << " (" << currentFileSize_ << " bytes)" << std::endl;
}
// Appends one FileData chunk to the file currently being received and adds its length
// to receivedBytes_. Chunks that arrive while no file stream is open are dropped.
void TcpServer::HandleFileData(const std::vector<uint8_t>& payload) {
    std::lock_guard<std::mutex> guard(fileMutex_);
    const bool haveOpenFile = currentFileStream_ != nullptr && currentFileStream_->is_open();
    if (!haveOpenFile) {
        return;
    }
    currentFileStream_->write(reinterpret_cast<const char*>(payload.data()), payload.size());
    receivedBytes_ += payload.size();
    // Progress output is intentionally disabled. If re-enabled, guard against
    // currentFileSize_ == 0 before dividing by it.
}
void TcpServer::HandleFileEnd() {
std::lock_guard<std::mutex> lock(fileMutex_);
if (currentFileStream_) {
currentFileStream_->close();
delete currentFileStream_;
currentFileStream_ = nullptr;
std::cout << "\nFile received successfully!" << std::endl;
if (fileReceivedCallback_) {
fileReceivedCallback_(currentFileName_);
}
}
}
bool TcpServer::SendFile(const std::string& filePath) {
if (clientSocket_ == INVALID_SOCKET) {
std::cerr << "No client connected" << std::endl;
return false;
}
std::ifstream file(filePath, std::ios::binary | std::ios::ate);
if (!file.is_open()) {
std::cerr << "Cannot open file: " << filePath << std::endl;
return false;
}
uint64_t fileSize = file.tellg();
file.seekg(0, std::ios::beg);
fs::path p(filePath);
std::string filename = p.filename().string();
// 1. Send Header
CommonHeader header;
header.type = (uint8_t)PacketType::FileHeader;
header.payloadSize = sizeof(FileMetadata);
FileMetadata meta;
meta.fileSize = fileSize;
strncpy_s(meta.fileName, filename.c_str(), sizeof(meta.fileName) - 1);
if (!SendBytes(clientSocket_, &header, sizeof(header))) return false;
if (!SendBytes(clientSocket_, &meta, sizeof(meta))) return false;
// 2. Send Data
const int CHUNK_SIZE = 64 * 1024; // 64KB
std::vector<char> buffer(CHUNK_SIZE);
uint64_t sent = 0;
while (sent < fileSize) {
file.read(buffer.data(), CHUNK_SIZE);
int bytesRead = (int)file.gcount();
CommonHeader dataHeader;
dataHeader.type = (uint8_t)PacketType::FileData;
dataHeader.payloadSize = bytesRead;
if (!SendBytes(clientSocket_, &dataHeader, sizeof(dataHeader))) return false;
if (!SendBytes(clientSocket_, buffer.data(), bytesRead)) return false;
sent += bytesRead;
std::cout << "\rSending: " << (sent * 100 / fileSize) << "%" << std::flush;
}
std::cout << std::endl;
// 3. Send End
CommonHeader endHeader;
endHeader.type = (uint8_t)PacketType::FileEnd;
endHeader.payloadSize = 0;
SendBytes(clientSocket_, &endHeader, sizeof(endHeader));
std::cout << "File sent successfully" << std::endl;
return true;
}
/// @brief Registers the function that HandleFileEnd invokes with the local path of each
///        completed file. Replaces any previously registered callback.
/// @param cb New callback, taken by value.
void TcpServer::SetFileReceivedCallback(FileReceivedCallback cb) {
    // cb is already our own copy; move it into place instead of copying it again.
    fileReceivedCallback_ = std::move(cb);
}
bool TcpServer::SendFolder(const std::string& folderPath) {
if (clientSocket_ == INVALID_SOCKET) {
std::cerr << "No client connected" << std::endl;
return false;
}
fs::path root(folderPath);
if (!fs::exists(root) || !fs::is_directory(root)) {
std::cerr << "Invalid folder: " << folderPath << std::endl;
return false;
}
std::string rootName = root.filename().string();
{
CommonHeader hdr;
hdr.type = (uint8_t)PacketType::FolderHeader;
hdr.payloadSize = (uint32_t)rootName.size();
if (!SendBytes(clientSocket_, &hdr, sizeof(hdr))) return false;
if (hdr.payloadSize > 0) {
if (!SendBytes(clientSocket_, rootName.data(), (int)rootName.size())) return false;
}
}
auto norm = [](std::string s) {
for (auto& c : s) if (c == '\\') c = '/';
return s;
};
for (auto it = fs::recursive_directory_iterator(root); it != fs::recursive_directory_iterator(); ++it) {
const fs::path p = it->path();
fs::path rel = fs::relative(p, root);
std::string relStr = norm(rel.string());
if (it->is_directory()) {
CommonHeader hdr;
hdr.type = (uint8_t)PacketType::DirEntry;
hdr.payloadSize = (uint32_t)relStr.size();
if (!SendBytes(clientSocket_, &hdr, sizeof(hdr))) return false;
if (hdr.payloadSize > 0) {
if (!SendBytes(clientSocket_, relStr.data(), (int)relStr.size())) return false;
}
} else if (it->is_regular_file()) {
uint64_t fileSize = (uint64_t)fs::file_size(p);
std::vector<uint8_t> header;
uint16_t pathLen = (uint16_t)std::min<size_t>(relStr.size(), 65535);
header.resize(8 + 2 + pathLen);
std::memcpy(header.data(), &fileSize, 8);
std::memcpy(header.data() + 8, &pathLen, 2);
std::memcpy(header.data() + 10, relStr.data(), pathLen);
CommonHeader hdr;
hdr.type = (uint8_t)PacketType::FileHeaderV2;
hdr.payloadSize = (uint32_t)header.size();
if (!SendBytes(clientSocket_, &hdr, sizeof(hdr))) return false;
if (!SendBytes(clientSocket_, header.data(), (int)header.size())) return false;
std::ifstream file(p, std::ios::binary);
if (!file.is_open()) continue;
const int CHUNK_SIZE = 64 * 1024;
std::vector<char> buffer(CHUNK_SIZE);
while (file) {
file.read(buffer.data(), CHUNK_SIZE);
int bytesRead = (int)file.gcount();
if (bytesRead <= 0) break;
CommonHeader dh;
dh.type = (uint8_t)PacketType::FileData;
dh.payloadSize = bytesRead;
if (!SendBytes(clientSocket_, &dh, sizeof(dh))) { file.close(); return false; }
if (!SendBytes(clientSocket_, buffer.data(), bytesRead)) { file.close(); return false; }
}
file.close();
CommonHeader endH;
endH.type = (uint8_t)PacketType::FileEnd;
endH.payloadSize = 0;
if (!SendBytes(clientSocket_, &endH, sizeof(endH))) return false;
}
}
CommonHeader fin;
fin.type = (uint8_t)PacketType::FolderEnd;
fin.payloadSize = 0;
if (!SendBytes(clientSocket_, &fin, sizeof(fin))) return false;
std::cout << "Folder sent successfully" << std::endl;
return true;
}
/// @brief Starts receiving a folder. Sets baseFolderRoot_ to
///        <USERPROFILE>\Desktop\<name> and creates that directory (best effort).
///
/// <name> is the FolderHeader payload, cut at the first NUL and reduced to its last path
/// component. "AndroidFolder" is used when nothing usable remains. When USERPROFILE is
/// unset the folder is created relative to the working directory.
/// @param payload Root folder name sent by the peer (may be empty).
void TcpServer::HandleFolderHeader(const std::vector<uint8_t>& payload) {
    // getenv() returns nullptr when the variable is unset.
    std::string desktopPath;
    if (const char* profile = std::getenv("USERPROFILE")) {
        desktopPath = std::string(profile) + "\\Desktop\\";
    }

    std::string rootName;
    if (!payload.empty()) {
        rootName.assign(reinterpret_cast<const char*>(payload.data()), payload.size());
        const size_t nul = rootName.find('\0');
        if (nul != std::string::npos) rootName.erase(nul);
    }
    // The name comes from the network. Keep only its last component so it cannot point
    // outside the Desktop (e.g. "..\\..\\Windows" or an absolute path).
    const size_t sep = rootName.find_last_of("/\\:");
    if (sep != std::string::npos) rootName.erase(0, sep + 1);
    if (rootName.empty() || rootName == "." || rootName == "..") {
        rootName = "AndroidFolder";
    }

    baseFolderRoot_ = desktopPath + rootName;
    try {
        fs::create_directories(baseFolderRoot_);
    } catch (const std::exception& e) {
        // Best effort, as before, but no longer silent.
        std::cerr << "Cannot create folder " << baseFolderRoot_ << ": " << e.what() << std::endl;
    }
    std::cout << "Receiving folder: " << baseFolderRoot_ << std::endl;
}
void TcpServer::HandleDirEntry(const std::vector<uint8_t>& payload) {
if (baseFolderRoot_.empty()) return;
std::string rel;
rel.assign(reinterpret_cast<const char*>(payload.data()), payload.size());
size_t pos = rel.find('\0');
if (pos != std::string::npos) rel = rel.substr(0, pos);
std::string path = baseFolderRoot_ + "\\" + rel;
try {
fs::create_directories(path);
} catch (...) {}
std::cout << "Create dir: " << path << std::endl;
}
void TcpServer::HandleFileHeaderV2(const std::vector<uint8_t>& payload) {
if (payload.size() < sizeof(uint64_t) + sizeof(uint16_t)) return;
const uint8_t* p = payload.data();
uint64_t sz = *reinterpret_cast<const uint64_t*>(p);
p += sizeof(uint64_t);
uint16_t pathLen = *reinterpret_cast<const uint16_t*>(p);
p += sizeof(uint16_t);
if (payload.size() < sizeof(uint64_t) + sizeof(uint16_t) + pathLen) return;
std::string rel(reinterpret_cast<const char*>(p), pathLen);
std::string full = baseFolderRoot_.empty() ? rel : (baseFolderRoot_ + "\\" + rel);
{
std::lock_guard<std::mutex> lock(fileMutex_);
if (currentFileStream_) {
delete currentFileStream_;
currentFileStream_ = nullptr;
}
currentFileName_ = full;
currentFileSize_ = sz;
receivedBytes_ = 0;
fs::create_directories(fs::path(full).parent_path());
currentFileStream_ = new std::ofstream(full, std::ios::binary);
}
std::cout << "Receiving file: " << full << " (" << sz << " bytes)" << std::endl;
}