diff --git a/CMakeLists.txt b/CMakeLists.txt index 4338a15..117471b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,23 +14,10 @@ project ("Socket") # ソースをこのプロジェクトの実行可能ファイルに追加します。 add_executable (Socket "Socket.cpp" +) - # include - "include/common.h" - "include/Packet.h" - "include/Socket.h" - - # include/Cryptgraphy - "include/Cryptgraphy/AES128.h" - "include/Cryptgraphy/common.h" - "include/Cryptgraphy/ECDSA.h" - "include/Cryptgraphy/ECPoint.h" - "include/Cryptgraphy/KeyManager.h" - "include/Cryptgraphy/ModInt.h" - "include/Cryptgraphy/MultiWordInt.h" - "include/Cryptgraphy/NumberSet.h" - "include/Cryptgraphy/RandomGenerator.h" - "include/Cryptgraphy/SHAKE256.h" +include_directories( + ${PROJECT_SOURCE_DIR}/include ) if (CMAKE_VERSION VERSION_GREATER 3.12) diff --git a/Socket.cpp b/Socket.cpp index c71f978..475d60b 100644 --- a/Socket.cpp +++ b/Socket.cpp @@ -1,14 +1,10 @@ -#include -#include -#include +#include "include/Socket.h" #include -#include "include/Socket.h" - void Server(); void Client(); -AES128::cbytearray<16> sharedkey = {'0', 'x', '7', '4', '0', 'x', '6', '5', '0', 'x', '7', '3', '0', 'x', '7', '4', }; +AES128::cbytearray<16> sharedkey = {'0', 'x', '7', '4', '0', 'x', '6', '5', '0', 'x', '7', '3', '0', 'x', '7', '4',}; struct ClientData { @@ -21,7 +17,7 @@ struct ClientData { Packet::StoreBytes(ret, Name); return ret; } - + Packet::byte_view FromBytes(Packet::byte_view view) { Packet::LoadBytes(view, Level); Packet::LoadBytes(view, Name); @@ -29,165 +25,25 @@ struct ClientData { } }; -struct ContainerInContainer { +int main(int argc, char* argv[]) { - std::vector names; + // arg[1]{ 0 = server, 1 = client } - Packet::bytearray ToBytes() const { - Packet::bytearray ret; - Packet::StoreBytes(ret, names); - return ret; - } + std::vector args; + args.insert(args.end(), argv, argv + argc); - Packet::byte_view FromBytes(Packet::byte_view view) { - Packet::LoadBytes(view, names); - return view; + if (args.size() <= 1) { + return -1; } -}; - -struct ContainerInVariable { - std::vector container; - Packet::bytearray ToBytes() const { - Packet::bytearray ret; - Packet::StoreBytes(ret, container); - return ret; + if (std::stoi(args[1]) == 0) { + Server(); } - - Packet::byte_view FromBytes(Packet::byte_view view) { - Packet::LoadBytes(view, container); - return view; - } -}; -#include "include/Cryptgraphy/KeyManager.h" - -int main(int argc, char* argv[]) { - - //KeyManager Keya; - //KeyManager Keyb; - // - //auto tp = std::chrono::high_resolution_clock::now(); - // - //auto kE = Keya.MakeQKey(); - //auto kF = Keyb.MakeQKey(); - // - //auto Ga = Keya.MakeSharedKey(kF); - //auto Gb = Keyb.MakeSharedKey(kE); - // - //auto ns = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - tp).count(); - // - //bool same = Ga == Gb; - // - //std::cout << (double)ns / 1000 / 1000 / 1000 << "s" << std::endl; - //std::cout << std::boolalpha << "shared key same: " << same << std::endl; - // - //for (auto&& b : Ga) { - // std::cout << std::hex << std::setw(2) << std::setfill('0') << std::right << (int)b; - //} - //std::cout << std::endl; - // - //for (auto&& b : Gb) { - // std::cout << std::hex << std::setw(2) << std::setfill('0') << std::right << (int)b; - //} - //std::cout << std::endl; - - KeyManager key; - std::string message = "I have skill is write low level programing language."; - - auto q = ECDSA::MakePublicKey(key.GetSecretKey()); - - auto v = ECDSA::Sign(key.GetSecretKey(), {message.begin(), message.end()}); - - bool ret = ECDSA::Verify(q, v, {message.begin(), message.end()}); - - std::cout << "message: \"" << message << "\"" << std::endl; - std::cout << "Q: {" << q.x.value.ToString(16) << ", " << q.y.value.ToString(16) << "}" << std::endl; - std::cout << "(r, s)(bytes): "; - for (auto&& b : v) { - std::cout << std::hex << std::setw(2) << std::setfill('0') << std::right << (int)b; + else { + Client(); } - std::cout << std::endl; - - std::cout << std::boolalpha << ret; - - //std::string message = "0123456789abcdef"; - //Cryptgraphy::bytearray data{message.begin(), message.end()}; - // - //auto tp = std::chrono::high_resolution_clock::now(); - // - //auto ret = SHAKE256::HasherN(data, 64); - // - //auto ns = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - tp).count(); - // - //std::cout << (double)ns / 1000 / 1000 << "ms" << std::endl; - //std::cout << std::boolalpha << "hash: "; - // - //for (auto&& c : ret) { - // std::cout << std::hex << std::right << std::setw(2) << std::setfill('0') << (int)c; - //} - - //using int_t = bigint<8>; - //using modint_t = ModInt; - //using projective_t = ECProjective; - // - //modint_t::Factory xmodp = "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"; - //projective_t::Factory projective = WeierstrassParameter( - // xmodp("ffffffff00000001000000000000000000000000fffffffffffffffffffffffc"), - // xmodp("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b") - //); - // - //auto G = projective( - // xmodp("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296"), - // xmodp("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5"), - // xmodp(1) - //); - // - //auto view = [](const std::string& name, const projective_t& p) { - // std::cout << name << ": {" - // << p.x.value.ToString(16) << ", " - // << p.y.value.ToString(16) << ", " - // << p.z.value.ToString(16) << "}" - // << std::endl; - //}; - //auto check = [](const projective_t& p) { - // auto a = p.ToAfinPoint(); - // std::cout << "Check: " << std::boolalpha - // << a.GetParam().CheckPoint(a.x, a.y) - // << std::endl; - //}; - // - //auto d = G.Double(); - // - //view("double", d); - //check(d); - // - //auto add = G.Add(d); - // - //view("add", add); - //check(add); - // - //auto scaler = G.Scaler(xmodp(100)); - // - //view("scaler", scaler); - //check(scaler); - - // arg[1]{ 0 = server, 1 = client } - //std::vector args; - //args.insert(args.end(), argv, argv + argc); - // - //if (args.size() <= 1) { - // return -1; - //} - // - //if (std::stoi(args[1]) == 0) { - // Server(); - //} - //else { - // Client(); - //} - // - //return 0; + return 0; } void Server() { @@ -276,7 +132,7 @@ void Server() { if (oc == c) { continue; } - oc.EncryptionSend(send); + oc.EncryptionSend(Packet(send)); } } } @@ -312,7 +168,7 @@ void Client() { server.CryptEngine.Init(sharedkey); ClientData _data; - + std::cout << "input your Level\n"; std::cin >> _data.Level; std::cout << "input your Name\n"; @@ -331,7 +187,7 @@ void Client() { while (!stopflag) { std::string sendval; std::cin >> sendval; - + if (sendval == "/exit") { stopflag = true; break; @@ -339,8 +195,7 @@ void Client() { std::lock_guard lock(mtx); - Packet pak = Packet(sendval); - server.EncryptionSend(sendval); + server.EncryptionSend(Packet(sendval)); } } }; @@ -350,13 +205,13 @@ void Client() { if (server.LostConnection()) { break; } - + if (server.Available() <= 0) { continue; } auto pak = server.EncryptionRecv(); - + if (!pak) { continue; } diff --git a/example/BytesConvert/BytesConvert.cpp b/example/BytesConvert/BytesConvert.cpp new file mode 100644 index 0000000..55767dd --- /dev/null +++ b/example/BytesConvert/BytesConvert.cpp @@ -0,0 +1,84 @@ +#include "include/Socket.h" + +struct ContainerInContainer { + + std::vector names; + + Packet::bytearray ToBytes() const { + Packet::bytearray ret; + Packet::StoreBytes(ret, names); + return ret; + } + + Packet::byte_view FromBytes(Packet::byte_view view) { + Packet::LoadBytes(view, names); + return view; + } +}; + +struct ContainerInVariable { + std::vector container; + + Packet::bytearray ToBytes() const { + Packet::bytearray ret; + Packet::StoreBytes(ret, container); + return ret; + } + + Packet::byte_view FromBytes(Packet::byte_view view) { + Packet::LoadBytes(view, container); + return view; + } +}; + +int main(int argc, char* argv[]) { + + ContainerInVariable data{}; + ContainerInContainer cic{}; + + std::string str = "test"; + + cic.names.push_back(str); str += "t"; + cic.names.push_back(str); str += "t"; + cic.names.push_back(str); str += "t"; + cic.names.push_back(str); str += "t"; + + data.container.push_back(cic); + + str = "test2"; + + cic.names.push_back(str); str += "b"; + cic.names.push_back(str); str += "b"; + cic.names.push_back(str); str += "b"; + cic.names.push_back(str); str += "b"; + + data.container.push_back(cic); + + str = "magic"; + + cic.names.push_back(str); str += "m"; + cic.names.push_back(str); str += "m"; + cic.names.push_back(str); str += "m"; + cic.names.push_back(str); str += "m"; + + data.container.push_back(cic); + + str = "test"; + + cic.names.push_back(str); str += "z"; + cic.names.push_back(str); str += "z"; + cic.names.push_back(str); str += "z"; + cic.names.push_back(str); str += "z"; + + data.container.push_back(cic); + + Packet pak = Packet(data); + + auto& buf = pak.GetBuffer(); + + for (auto&& c : buf) { + std::cout << std::uppercase << std::setfill('0') << std::setw(2) << std::hex << std::right << static_cast(c); + } + + return 0; +} \ No newline at end of file diff --git a/example/EC_Signature/EC_Signature.cpp b/example/EC_Signature/EC_Signature.cpp new file mode 100644 index 0000000..d57de07 --- /dev/null +++ b/example/EC_Signature/EC_Signature.cpp @@ -0,0 +1,27 @@ +#include "include/Socket.h" + +int main(int argc, char* argv[]) { + + KeyManager key; + std::string message = "I have skill is write low level programing language."; + + auto q = ECDSA::MakePublicKey(key.GetSecretKey()); + + auto v = ECDSA::Sign(key.GetSecretKey(), {message.begin(), message.end()}); + + bool ret = ECDSA::Verify(q, v, {message.begin(), message.end()}); + + std::cout << "message: \"" << message << "\"" << std::endl; + std::cout << "Q: {" << q.x.value.ToString(16) << ", " << q.y.value.ToString(16) << "}" << std::endl; + std::cout << "(r, s)(bytes): "; + for (auto&& b : v) { + std::cout << std::hex << std::setw(2) << std::setfill('0') << std::right << (int)b; + } + std::cout << std::endl; + + std::cout << std::boolalpha << ret; + + assert(ret); + + return 0; +} \ No newline at end of file diff --git a/example/KeyExchange/KeyExchange.cpp b/example/KeyExchange/KeyExchange.cpp new file mode 100644 index 0000000..e4c8f2f --- /dev/null +++ b/example/KeyExchange/KeyExchange.cpp @@ -0,0 +1,36 @@ +#include "include/Cryptgraphy/KeyManager.h" + +int main(int argc, char* argv[]) { + + KeyManager Keya; + KeyManager Keyb; + + auto tp = std::chrono::high_resolution_clock::now(); + + auto kE = Keya.MakeQKey(); + auto kF = Keyb.MakeQKey(); + + auto Ga = Keya.MakeSharedKey(kF); + auto Gb = Keyb.MakeSharedKey(kE); + + auto ns = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - tp).count(); + + bool same = Ga == Gb; + + std::cout << (double)ns / 1000 / 1000 / 1000 << "s" << std::endl; + std::cout << std::boolalpha << "shared key same: " << same << std::endl; + + for (auto&& b : Ga) { + std::cout << std::hex << std::setw(2) << std::setfill('0') << std::right << (int)b; + } + std::cout << std::endl; + + for (auto&& b : Gb) { + std::cout << std::hex << std::setw(2) << std::setfill('0') << std::right << (int)b; + } + std::cout << std::endl; + + assert(same); + + return 0; +} \ No newline at end of file diff --git a/example/Network/Network.cpp b/example/Network/Network.cpp new file mode 100644 index 0000000..cd83118 --- /dev/null +++ b/example/Network/Network.cpp @@ -0,0 +1,228 @@ +#include "include/Socket.h" +#include + +void Server(); +void Client(); + +AES128::cbytearray<16> sharedkey = {'0', 'x', '7', '4', '0', 'x', '6', '5', '0', 'x', '7', '3', '0', 'x', '7', '4',}; + +struct ClientData { + + int Level = 0; + std::string Name = "NoName"; + + Packet::bytearray ToBytes() const { + Packet::bytearray ret; + Packet::StoreBytes(ret, Level); + Packet::StoreBytes(ret, Name); + return ret; + } + + Packet::byte_view FromBytes(Packet::byte_view view) { + Packet::LoadBytes(view, Level); + Packet::LoadBytes(view, Name); + return view; + } +}; + +int main(int argc, char* argv[]) { + + // arg[1]{ 0 = server, 1 = client } + + std::vector args; + args.insert(args.end(), argv, argv + argc); + + if (args.size() <= 1) { + return -1; + } + + if (std::stoi(args[1]) == 0) { + Server(); + } + else { + Client(); + } + + return 0; +} + +void Server() { + + TCPServer server(8080); + + std::map> clients; + std::vector> joinqueue; + std::deque lostqueue; + + while (true) { + auto sock = server.Accept(); + + if (sock) { + bool emptyfound = false; + for (auto&& state : joinqueue) { + if (!state) { + state = std::move(*sock); + emptyfound = true; + break; + } + } + if (!emptyfound) { + joinqueue.push_back(std::move(*sock)); + } + } + + for (auto&& [_, pair] : clients) { + auto&& [c, cd] = pair; + if (c.LostConnection()) { + lostqueue.push_back(&c); + std::cout << "lost connection: " << cd.Name << std::endl; + } + } + + for (auto&& c : joinqueue) { + + if (!c) { + continue; + } + + if (c->Available() <= 0) { + continue; + } + + c->CryptEngine.Init(sharedkey); + + auto cd = c->EncryptionRecv()->Get(); + + if (cd) { + std::cout << "connected: " << cd->Name << std::endl; + auto addr = c->GetPeerAddress(); + clients[*addr] = {std::move(*c), std::move(*cd)}; + c.reset(); + } + } + + while (!lostqueue.empty()) { + auto p = lostqueue.front(); + lostqueue.pop_front(); + + clients.erase(*p->GetPeerAddress()); + } + + for (auto&& [_, pair] : clients) { + auto&& [c, cd] = pair; + + int available = c.Available(); + + if (available <= 0) { + continue; + } + + auto val = c.EncryptionRecv(); + + if (!val) { + continue; + } + + std::string send = cd.Name + "(" + std::to_string(cd.Level) + "): " + *val->Get(); + + std::cout << send << std::endl; + + for (auto&& [_, topair] : clients) { + auto&& [oc, __] = topair; + if (oc == c) { + continue; + } + oc.EncryptionSend(Packet(send)); + } + } + } +} + +void Client() { + + TCPSocket server; + + std::cout << "input connect server address" << std::endl; + std::string str_addr; + std::cin >> str_addr; + + auto op_addr = IPAddress::SolveHostName(str_addr); + + if (!op_addr) { + std::cout << "can't solved address" << std::endl; + return; + } + + std::cout << "input port" << std::endl; + unsigned short port; + std::cin >> port; + + if (server.Connect(op_addr->Port(port))) { + std::cout << "connected server." << std::endl; + } + else { + std::cout << "can't connect server." << std::endl; + return; + } + + server.CryptEngine.Init(sharedkey); + + ClientData _data; + + std::cout << "input your Level\n"; + std::cin >> _data.Level; + std::cout << "input your Name\n"; + std::cin >> _data.Name; + + Packet p = Packet(_data); + + server.EncryptionSend(p); + + bool stopflag = false; + + std::mutex mtx; + + std::thread inputthread = std::thread{ + [&] { + while (!stopflag) { + std::string sendval; + std::cin >> sendval; + + if (sendval == "/exit") { + stopflag = true; + break; + } + + std::lock_guard lock(mtx); + + server.EncryptionSend(Packet(sendval)); + } + } + }; + + while (!stopflag) { + + if (server.LostConnection()) { + break; + } + + if (server.Available() <= 0) { + continue; + } + + auto pak = server.EncryptionRecv(); + + if (!pak) { + continue; + } + + std::lock_guard lock(mtx); + + auto val = *pak->Get(); + std::cout << val << std::endl; + } + + stopflag = true; + + inputthread.join(); +} diff --git a/example/template/template.cpp b/example/template/template.cpp new file mode 100644 index 0000000..c320a3c --- /dev/null +++ b/example/template/template.cpp @@ -0,0 +1,6 @@ +#include "include/Socket.h" + +int main(int argc, char* argv[]) { + + return 0; +} \ No newline at end of file diff --git a/include/Cryptgraphy/MultiWordInt.h b/include/Cryptgraphy/MultiWordInt.h index dc4b176..c7b204b 100644 --- a/include/Cryptgraphy/MultiWordInt.h +++ b/include/Cryptgraphy/MultiWordInt.h @@ -1,6 +1,10 @@ #pragma once #include "common.h" +/// +/// fixed-size +/// + template struct bigint { using count_t = size_t; @@ -17,9 +21,9 @@ struct bigint { static constexpr count_t WordCharSize = WordByte * 2; static constexpr count_t AllBits = Words * WordBits; - static_assert(Words > 0, "invalid WordCount"); - using arr_t = std::array; + using arr_view = std::span; + using arr_ref = std::span; using bits_t = std::bitset; using signed_t = bigint; using unsigned_t = bigint; @@ -81,7 +85,15 @@ struct bigint { /// Assignment Operator Module constexpr bigint& operator=(const bigint& from) noexcept { *m_words = *from.m_words; return *this; } - constexpr bigint& operator=(bigint&& from) noexcept { delete m_words; m_words = from.m_words; from.m_words = nullptr; return *this; } + constexpr bigint& operator=(bigint&& from) noexcept { + if (from.m_words == m_words) { + return *this; + } + delete m_words; + m_words = from.m_words; + from.m_words = nullptr; + return *this; + } constexpr bigint& operator=(word_t from) noexcept requires(!IsSigned) { *this = std::move(bigint(from)); return *this; @@ -207,16 +219,32 @@ struct bigint { constexpr operator unsigned_t& () requires(!IsSigned) { return *this; // TODO: remove unneccesary conversion } + template + requires (std::is_convertible_v>) + constexpr bigint& FromWords(R&& r) { + auto beg = m_words->begin(); + auto end = m_words->end(); + for (const auto&& elem : r) { + if (beg == end) { + break; + } + *(beg++) = elem; + } + for (; beg != end; ++beg) { + *beg = 0; + } + return *this; + } /// Arithmetic Module - static constexpr bool AddBase(word_t *dest, word_t src, bool carry) { + static constexpr bool AddBase(word_t *dest, word_t src, bool carry) noexcept { word_t a = *dest; word_t b = src + static_cast(carry); *dest += b; return (b < src) || (*dest < a); } - constexpr bigint& AssignAdd(const bigint& src) { + constexpr bigint& AssignAdd(const bigint& src) noexcept { bool carry = false; for (count_t i = 0; i < Words; ++i) { carry = AddBase( @@ -265,41 +293,89 @@ struct bigint { return { t1, t2 }; } - constexpr bigint& AssignMul(bigint src) { - - bigint base = *this; - *this = 0; + static constexpr bigint NormalMul(const bigint& x, const bigint& y) { + bigint ret = 0; + for (count_t j = 0; j < Words; ++j) { - const word_t src_word = src.words()[j]; - - if (src_word == 0) { + const word_t y_word = y.words()[j]; + + if (y_word == 0) { continue; } - + word_t carry = 0; bool carryflag = false; - + for (count_t i = 0; i + j < Words; ++i) { word_t temp = carry; - + const auto [lower, upper] = MulBase( - src_word, - base.words()[i] + y_word, + x.words()[i] ); - + carryflag = AddBase(&temp, lower, carryflag); - + carry = upper + carryflag; - + carryflag = AddBase( - std::addressof(this->words()[i + j]), + std::addressof(ret.words()[i + j]), temp, false ); } } + + return ret; + } + static constexpr bigint Karatuba(const bigint& x, const bigint& y) { + bigint ret = 0; - return *this; + if (x == 0 || y == 0) { + return ret; + } + + count_t nbit = std::max(x.GetNBit(), y.GetNBit()); + count_t halfbits = (nbit + (nbit & 1)) / 2; + + if (halfbits <= WordBits * 2) { + ret = NormalMul(x, y); + return ret; + } + + bigint halfmask = (bigint(1) << halfbits) - 1; + + bigint xl = x; + bigint xh = x; + bigint yl = y; + bigint yh = y; + + xl &= halfmask; + xh >>= halfbits; + yl &= halfmask; + yh >>= halfbits; + + bigint z0 = Karatuba(xl, yl); + bigint z2 = Karatuba(xh, yh); + + xl += xh; + yl += yh; + bigint z1 = Karatuba(xl, yl); + + z1 -= z0; + z1 -= z2; + ret += z0; + + z1 <<= halfbits; + z2 <<= (2 * halfbits); + + ret += z1; + ret += z2; + + return ret; + } + constexpr bigint& AssignMul(bigint src) { + return *this = NormalMul(*this, src); } constexpr std::pair AssignDivMod(bigint src) { @@ -347,11 +423,11 @@ struct bigint { constexpr friend bool operator<=(const bigint& lhs, const bigint& rhs) { return lhs.Compare(rhs) <= 0; } constexpr friend bool operator> (const bigint& lhs, const bigint& rhs) { return lhs.Compare(rhs) > 0; } constexpr friend bool operator>=(const bigint& lhs, const bigint& rhs) { return lhs.Compare(rhs) >= 0; } - constexpr bigint& AssignLeftShift(word_t c) { + constexpr bigint& AssignLeftShift(count_t c) { bits() <<= c; return *this; } - constexpr bigint& AssignRightShift(word_t c) { + constexpr bigint& AssignRightShift(count_t c) { if constexpr (IsSigned) { if (this->IsNegative()) { unsigned_t shiftmask = 1; @@ -370,10 +446,10 @@ struct bigint { } return *this; } - constexpr bigint& operator<<=(word_t c) { return AssignLeftShift(c); } - constexpr bigint& operator>>=(word_t c) { return AssignRightShift(c); } - constexpr friend bigint operator<<(bigint lhs, word_t c) { return lhs.AssignLeftShift(c); } - constexpr friend bigint operator>>(bigint lhs, word_t c) { return lhs.AssignRightShift(c); } + constexpr bigint& operator<<=(count_t c) { return AssignLeftShift(c); } + constexpr bigint& operator>>=(count_t c) { return AssignRightShift(c); } + constexpr friend bigint operator<<(bigint lhs, count_t c) { return lhs.AssignLeftShift(c); } + constexpr friend bigint operator>>(bigint lhs, count_t c) { return lhs.AssignRightShift(c); } constexpr bigint& AssignNot() { bits().flip(); return *this; } constexpr bigint operator~() const { return bigint(*this).AssignNot(); } constexpr bigint& AssignAnd(const bigint& src) { bits() &= src.bits(); return *this; } @@ -397,12 +473,14 @@ struct bigint { (c - ('a' - 'A')) : (c); } + static constexpr std::string_view DigitsTable = "0123456789abcdefghijklmnopqrstuvwxyz"; + static constexpr auto DigitsTableUpper = DigitsTable | std::ranges::views::transform([](auto x) { return ToUpper(x); }); + static constexpr std::string WordToString(word_t v, int base) { - constexpr std::string_view list = "0123456789abcdefghijkmnlopqrstuvwxyz"; std::string ret; ret.reserve(WordCharSize); while (v != 0) { - ret.push_back(list[v % base]); + ret.push_back(DigitsTable[v % base]); v /= base; } std::reverse(ret.begin(), ret.end()); @@ -414,13 +492,11 @@ struct bigint { word_t ret = 0; auto getidx = [&](char c) -> size_t { - constexpr std::string_view listlower = "0123456789abcdefghijkmnlopqrstuvwxyz"; - constexpr std::string_view listupper = "0123456789ABCDEFGHIJKMNLOPQRSTUVWXYZ"; - size_t idx = listlower.find(c); + size_t idx = DigitsTable.find(c); if (idx != std::string_view::npos) { return idx; } - return listupper.find(c); + return DigitsTableUpper.find(c); }; for (; it != end; ++it) { @@ -444,7 +520,7 @@ struct bigint { assert((base >= 2 && base <= 36) && "Invalid base"); - auto proc = text.substr(0, text.find_first_not_of("0123456789abcdefghijkmnlopqrstuvwxyzABCDEFGHIJKMNLOPQRSTUVWXYZ")); + auto proc = text.substr(0, text.find_first_not_of("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")); auto it = proc.rbegin(); auto end = proc.rend(); count_t c = 0; @@ -510,7 +586,7 @@ struct bigint { return ret; } constexpr std::string ToBase64() const { - constexpr std::string_view list = "ABCDEFGHIJKNMLOPQRSTUVWXYZabcdefghijknmlopqrstuvwxyz0123456789+/"; + constexpr std::string_view list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; std::string ret; ret.reserve(this->GetNBit() / std::log2(64) + 1); @@ -545,8 +621,6 @@ struct bigint { return ret; } constexpr std::string ToString(int base = 10, bool upper = true, bool padding = false) const { - constexpr std::string_view list = "0123456789abcdefghijkmnlopqrstuvwxyz"; - assert((base >= 2 && base <= 36) && "Invalid base"); word_t word_digits = static_cast(WordBits / std::log2(base)); @@ -622,8 +696,8 @@ struct bigint { /// Internal Resource - constexpr arr_t& words() { return *m_words; } - constexpr const arr_t& words() const { return *m_words; } + constexpr arr_t& words() noexcept { return *m_words; } + constexpr const arr_t& words() const noexcept { return *m_words; } constexpr bits_t& bits() { return *reinterpret_cast(m_words->data()); // TODO: resolve potential undefined behavior } @@ -635,3 +709,216 @@ struct bigint { arr_t* m_words = new arr_t(); }; + +#if 0 + +/// +/// variable-size +/// + +template +struct bigint<0, _sign> { + using count_t = size_t; + using diff_t = int64_t; + using word_t = uint64_t; + using sword_t = int64_t; + + static constexpr bool IsSigned = _sign; + // static constexpr count_t Words = 0; + static constexpr count_t WordByte = sizeof(word_t); + static constexpr count_t WordBits = WordByte * 8; + static constexpr count_t WordCharSize = WordByte * 2; + // static constexpr count_t WordBytes = Words * WordByte; + // static constexpr count_t AllBits = Words * WordBits; + + using arr_t = std::vector; + using arr_view = std::span; + using arr_ref = std::span; + using signed_t = bigint<0, true>; + using unsigned_t = bigint<0, false>; + + constexpr count_t GetWords() const noexcept { + return m_words.size(); + } + constexpr count_t GetWordBytes() const noexcept { + return GetWords() * WordByte; + } + constexpr count_t GetAllBits() const noexcept { + return GetWords() * WordBits; + } + constexpr count_t GetNWord() const noexcept { + for (count_t i = m_words.size(); i-- > 0;) { + if (m_words[i] != 0) { + return i + 1; + } + } + return GetWords(); + } + constexpr count_t GetNBit() const noexcept { + count_t idx = GetWords() - 1; + count_t word_nbit = std::bit_width(m_words[idx]); + return word_nbit == 0 ? GetAllBits() : idx * WordBits + word_nbit; + } + + constexpr void Resize(count_t newsize) noexcept { + m_words.resize(newsize, 0); + } + + static constexpr bool AddBase(word_t src, word_t *dest, bool carry) noexcept { + word_t a = *dest; + word_t b = src + static_cast(carry); + *dest += b; + return (b < src) || (*dest < a); + } + constexpr bigint& AssignAdd(const bigint& rhs) noexcept { + bool carry = false; + for (count_t i = 0; i < GetWords(); ++i) { + carry = AddBase( + rhs.m_words[i], + std::addressof(this->m_words[i]), + carry + ); + } + return *this; + } + + constexpr bigint& AssignNOT() noexcept { + std::transform( + std::execution::unseq, + m_words.begin(), + m_words.end(), + m_words.begin(), + std::bit_not() + ); + return *this; + } + constexpr bigint& AssignAND(const bigint& rhs) noexcept { + std::transform( + std::execution::unseq, + m_words.begin(), + m_words.end(), + rhs.m_words.begin(), + m_words.begin(), + std::bit_and() + ); + return *this; + } + constexpr bigint& AssignOR(const bigint& rhs) noexcept { + std::transform( + std::execution::unseq, + m_words.begin(), + m_words.end(), + rhs.m_words.begin(), + m_words.begin(), + std::bit_or() + ); + return *this; + } + constexpr bigint& AssignXOR(const bigint& rhs) noexcept { + std::transform( + std::execution::unseq, + m_words.begin(), + m_words.end(), + rhs.m_words.begin(), + m_words.begin(), + std::bit_xor() + ); + return *this; + } + static constexpr word_t WordShiftBase(word_t low, word_t high, count_t n) noexcept { + return (low >> n) | (high << (WordBits - n)); + } + constexpr bigint& AssignLeftShift(count_t n) noexcept { + count_t wordshift = n >> std::bit_width(WordBits - 1); + count_t bitshift = n & (WordBits - 1); + + for (count_t i = GetWords() - wordshift; i-- > 0;) { + m_words[i + wordshift] = WordShiftBase( + i == 0 ? 0 : m_words[i - 1], + m_words[i], + bitshift + ); + } + + auto offset = std::min(wordshift, GetWords()); + std::fill(std::execution::unseq, m_words.begin(), m_words.begin() + offset, 0); + + return *this; + } + constexpr bigint& AssignRightShift(count_t n) noexcept { + count_t wordshift = n >> std::bit_width(WordBits - 1); + count_t bitshift = WordBits - (n & (WordBits - 1)); + + for (count_t i = wordshift, c = GetWords(); i < c; ++i) { + m_words[i - wordshift] = WordShiftBase( + m_words[i], + m_words[i + 1], + bitshift + ); + } + + auto offset = std::min(wordshift, GetWords()); + std::fill(std::execution::unseq, m_words.rbegin(), m_words.rbegin() + offset, 0); + + return *this; + } + + constexpr bigint& operator~() noexcept { + return AssignNOT(); + } + constexpr bigint& operator&=(const bigint& rhs) noexcept { + return AssignAND(rhs); + } + constexpr bigint& operator|=(const bigint& rhs) noexcept { + return AssignOR(rhs); + } + constexpr bigint& operator^=(const bigint& rhs) noexcept { + return AssignXOR(rhs); + } + + static constexpr auto Compare(const bigint& lhs, const bigint& rhs) noexcept { + count_t words[2] = {lhs.GetWords(), rhs.GetWords()}; + auto [words_min, words_max] = std::minmax(words[0], words[1]); + + bool is_bigger_l = words_max == words[0]; + + const bigint& longer = is_bigger_l ? lhs : rhs; + const bigint& shorter = is_bigger_l ? rhs : lhs; + + if (!IsZeroInRef(arr_view(longer.m_words).last(words_max - words_min))) { + return is_bigger_l ? + std::strong_ordering::greater : std::strong_ordering::less; + } + + for (count_t i = words_min; i-- > 0;) { + auto com = lhs.m_words[i] <=> rhs.m_words[i]; + if (!std::is_eq(com)) { + return com; + } + } + + return std::strong_ordering::equal; + } + constexpr auto Compare(const bigint& rhs) const noexcept { + return Compare(*this, rhs); + } + friend constexpr auto operator<=>(const bigint& lhs, const bigint& rhs) noexcept { + return Compare(lhs, rhs); + } + +private: + + static constexpr bool IsZeroInRef(arr_view v) noexcept { + for (auto&& elem : v) { + if (elem != 0) { + return false; + } + } + return true; + } + + arr_t m_words{}; + +}; + +#endif \ No newline at end of file diff --git a/include/Cryptgraphy/common.h b/include/Cryptgraphy/common.h index 3849243..56951f2 100644 --- a/include/Cryptgraphy/common.h +++ b/include/Cryptgraphy/common.h @@ -9,9 +9,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include diff --git a/include/Packet.h b/include/Packet.h index 7e20988..f95b37a 100644 --- a/include/Packet.h +++ b/include/Packet.h @@ -89,7 +89,9 @@ struct Header { }; /// -/// Packet +/// Packet +/// | header (16 byte) | data (variable) | +/// | size and data type | raw binary data | /// struct Packet { @@ -98,6 +100,7 @@ struct Packet { using byte_t = SocketDetail::byte_t; using bytearray = SocketDetail::bytearray; + using header_bytes = SocketDetail::cbytearray; using byte_view = SocketDetail::byte_view; using byte_ref = SocketDetail::byte_ref; @@ -109,27 +112,27 @@ struct Packet { template static constexpr bool memcpyable = SocketDetail::memcpyable; - + template static constexpr bool to_byteable = SocketDetail::to_byteable; - + template static constexpr bool from_byteable = SocketDetail::from_byteable; template static constexpr bool cross_convertible = SocketDetail::cross_convertible; + Packet() {}; Packet(const Packet&) = default; Packet(Packet&&) = default; - Packet& operator=(const Packet&) = default; Packet& operator=(Packet&&) = default; - Packet() {}; - Packet(const bytearray&) = delete; - Packet(bytearray&&) = delete; - Packet& operator=(const bytearray&) = delete; - Packet& operator=(bytearray&&) = delete; + static Packet FromBytes(const bytearray& src) { + Packet ret; + ret.m_buffer = src; + return ret; + } Packet(uint32_t id, const void* src, uint32_t size) { Header head(id); @@ -144,42 +147,39 @@ struct Packet { Packet(uint32_t id, const bytearray& data) : Packet(id, data.data(), data.size()) {} template Packet(enumT type, const bytearray& data) requires (is_enum32) : Packet(type, data.data(), data.size()) {} - + template Packet(size_t id, const char(&data)[len]) : Packet(id, std::addressof(data), len - 1) {} template Packet(enumT type, const char(&data)[len]) requires (is_enum32) : Packet(static_cast(type), std::addressof(data), len - 1) {} template - Packet(const char(&data)[len]) : Packet(Header::type_hash_code(), std::addressof(data), len - 1) {} + explicit Packet(const char(&data)[len]) : Packet(Header::type_hash_code(), std::addressof(data), len - 1) {} Packet(uint32_t id, const std::string& data) : Packet(id, data.data(), data.size()) {} template Packet(enumT type, const std::string& data) requires (is_enum32) : Packet(type, data.data(), data.size()) {} - Packet(const std::string& data) : Packet(Header::type_hash_code(), data.data(), data.size()) {} - + explicit Packet(const std::string& data) : Packet(Header::type_hash_code(), data.data(), data.size()) {} + template Packet(uint32_t id, const T& data) requires (memcpyable && !cross_convertible) : Packet(id, std::addressof(data), sizeof(T)) {} template Packet(enumT type, const T& data) requires (is_enum32 && memcpyable && !cross_convertible) : Packet(static_cast(type), std::addressof(data), sizeof(T)) {} template - Packet(const T& data) requires (memcpyable && !cross_convertible) : Packet(Header::type_hash_code(), std::addressof(data), sizeof(T)) {} + explicit Packet(const T& data) requires (memcpyable && !cross_convertible) : Packet(Header::type_hash_code(), std::addressof(data), sizeof(T)) {} template Packet(uint32_t id, const std::vector& data) requires (memcpyable && !cross_convertible) : Packet(id, data.data(), data.size() * sizeof(T)) {} template Packet(enumT type, const std::vector& data) requires (is_enum32 && memcpyable && !cross_convertible) : Packet(static_cast(type), data.data(), data.size() * sizeof(T)) {} template - Packet(const std::vector& data) requires (memcpyable && !cross_convertible) : Packet(Header::type_hash_code>(), data.data(), data.size() * sizeof(T)) {} + explicit Packet(const std::vector& data) requires (memcpyable && !cross_convertible) : Packet(Header::type_hash_code>(), data.data(), data.size() * sizeof(T)) {} template - Packet(uint32_t id, const T& data) requires (cross_convertible) { - bytearray _data = Convert(data); - *this = Packet(id, _data.data(), _data.size()); - } + Packet(uint32_t id, const T& data) requires (cross_convertible) : Packet(id, Convert(data)) {}; template Packet(enumT type, const T& data) requires (is_enum32 && cross_convertible) : Packet(static_cast(type), data) {} template - Packet(const T& data) requires (cross_convertible) : Packet(Header::type_hash_code(), data) {} + explicit Packet(const T& data) requires (cross_convertible) : Packet(Header::type_hash_code(), data) {} template Packet(uint32_t id, const std::vector& data) requires (cross_convertible) { @@ -192,9 +192,9 @@ struct Packet { *this = Packet(id, b.data(), b.size()); } template - Packet(enumT type, const std::vector& data) requires (is_enum32 && cross_convertible) : Packet(static_cast(type), data) {} + Packet(enumT type, const std::vector& data) requires (is_enum32 && cross_convertible) : Packet(static_cast(type), data) {} template - Packet(const std::vector& data) requires (cross_convertible) : Packet(Header::type_hash_code>(), data) {} + explicit Packet(const std::vector& data) requires (cross_convertible) : Packet(Header::type_hash_code>(), data) {} Packet(uint32_t id, std::ifstream& ifs) { @@ -205,7 +205,7 @@ struct Packet { std::istreambuf_iterator begin(ifs); std::istreambuf_iterator end; - std::string data(begin, end); + bytearray data(begin, end); *this = Packet(id, data); } @@ -213,44 +213,20 @@ struct Packet { Packet(enumT type, std::ifstream& ifs) requires (is_enum32) : Packet(static_cast(type), ifs) {} explicit Packet(std::ifstream& ifs) : Packet(Header::type_hash_code(), ifs) {} - /* - - explicit Packet(uint32_t id, const std::filesystem::path& path) { - std::error_code ec; - if (path.empty() || !std::filesystem::exists(path, ec) || ec) { - return; - } - - const auto size = std::filesystem::file_size(path, ec); - if (ec) { - return; - } - - std::ifstream ifs(path, std::ios::binary); + size_t Size() const { return m_buffer.size(); } - if (!ifs.is_open()) { - return; + const bytearray& GetRawPacket() const { return m_buffer; } + std::optional GetRawData() const { + if (CheckHeader(0)) { + return std::nullopt; } - - buf_t data(size); - ifs.read(reinterpret_cast(data.data()), size); - - ifs.close(); - - *this = Packet(id, data); + return byte_view(m_buffer.begin(), m_buffer.end()).subspan(HeaderSize); } - template - explicit Packet(enumT type, const std::filesystem::path& path, Header::enum32 dummy_0 = {}) : Packet(static_cast(type), path) {} - explicit Packet(const std::filesystem::path& path) : Packet(Header::type_hash_code(), path) {} - - */ - - size_t Size() const { return m_buffer.size(); } - - const bytearray& GetBuffer() const { return m_buffer; } - Packet& SetBuffer(bytearray&& src) { - m_buffer = std::move(src); - return *this; + std::optional RefRawData() { + if (CheckHeader(0)) { + return std::nullopt; + } + return byte_ref(m_buffer.begin(), m_buffer.end()).subspan(HeaderSize); } std::optional
GetHeader() const { @@ -280,7 +256,7 @@ struct Packet { auto&& [ret, _] = Convert(byte_view(m_buffer).subspan(HeaderSize)); return ret; } - + template std::optional Get() const requires (std::same_as) { if (CheckHeader()) { @@ -293,13 +269,13 @@ struct Packet { } template - std::optional> GetArray() const requires (memcpyable && !from_byteable){ + std::optional> GetArray() const requires (memcpyable && !from_byteable) { if (CheckHeader()) { return std::nullopt; } size_t dataSize = (m_buffer.size() - HeaderSize) / sizeof(T); std::vector data(dataSize); - std::memcpy(data.data(), m_buffer.data() + HeaderSize, m_buffer.size() - HeaderSize); + std::memcpy(data.data(), m_buffer.data() + HeaderSize, dataSize * sizeof(T)); return data; } @@ -309,7 +285,7 @@ struct Packet { return std::nullopt; } std::vector ret; - byte_view view = byte_view(m_buffer.begin(), HeaderSize); + byte_view view = byte_view(m_buffer).subspan(HeaderSize); while (view.begin() < view.end()) { auto&& [elem, last] = Convert(view); ret.push_back(std::move(elem)); @@ -319,7 +295,7 @@ struct Packet { } template - static bytearray Convert(const T &from) requires (to_byteable) { + static bytearray Convert(const T& from) requires (to_byteable) { return from.ToBytes(); } @@ -329,7 +305,7 @@ struct Packet { byte_view view = ret.FromBytes(from); return {ret, view}; } - + static void StoreBytes(bytearray& dest, const void* src, uint32_t size) { dest.insert(dest.end(), static_cast(src), static_cast(src) + size); } diff --git a/include/Socket.h b/include/Socket.h index 3aa6d0d..5dbd5b9 100644 --- a/include/Socket.h +++ b/include/Socket.h @@ -21,7 +21,7 @@ #include "common.h" #ifdef SOCKET_H_USE_NAMESPACE -namespace NetIO { +namespace Socket { #endif // SOCKET_H_USE_NAMESPACE #include "Cryptgraphy/AES128.h" @@ -335,9 +335,9 @@ class SocketBase { static int Poll(poll_t* fds, unsigned int nfds, int timeout) { #ifdef _MSC_BUILD - int ret = WSAPoll(fds, nfds, 0); + int ret = WSAPoll(fds, nfds, timeout); #else - int ret = poll(fds, nfds, 0); + int ret = poll(fds, nfds, timeout); #endif // _MSC_BUILD return ret; } @@ -378,7 +378,12 @@ class basic_TCPSocket : public sockbase { public: - using bytearray = typename sockbase::bytearray; + using bytearray = SocketDetail::bytearray; + using byte_view = SocketDetail::byte_view; + using byte_ref = SocketDetail::byte_ref; + + template + static constexpr bool memcpyable = SocketDetail::memcpyable; basic_TCPSocket() : sockbase() {} basic_TCPSocket(typename sockbase::IPType addr) : basic_TCPSocket() { @@ -502,10 +507,10 @@ class basic_TCPSocket : public sockbase { return true; } - bool Send(const bytearray& src) { + bool Send(byte_view src) { return RawSend(src.data(), static_cast(src.size())); } - bool Recv(bytearray& dest) { + bool Recv(byte_ref dest) { if (dest.empty()) { return false; } return RawRecv(dest.data(), static_cast(dest.size())); } @@ -514,27 +519,26 @@ class basic_TCPSocket : public sockbase { if (src.CheckHeader()) { return false; } - return Send(src.GetBuffer()); + return Send(src.GetRawPacket()); } std::optional Recv() { - bytearray head(Packet::HeaderSize); - if (!Recv(head)) { + Packet::header_bytes headbuf{}; + if (!Recv(headbuf)) { return std::nullopt; } - Packet pak; - pak.SetBuffer(std::move(head)); - bytearray data(pak.GetHeader()->Size); + Header head = std::bit_cast
(headbuf); + bytearray data(head.Size); if (!Recv(data)) { return std::nullopt; } - return Packet(pak.GetHeader()->Type, data); + return Packet(head.Type, data); } - bool EncryptionSend(const bytearray& src) { - bytearray target; + bool EncryptionSend(byte_view src) { + bytearray target(src.size()); return Encrypt(src, target) && Send(target); } - bool EncryptionRecv(bytearray& dest) { + bool EncryptionRecv(byte_ref dest) { return Recv(dest) && Decrypt(dest, dest); } @@ -542,39 +546,36 @@ class basic_TCPSocket : public sockbase { if (src.CheckHeader()) { return false; } - bytearray data(src.GetBuffer().begin() + Packet::HeaderSize, src.GetBuffer().end()); - bool flag = Encrypt(data, data); - Packet pak = Packet(src.GetHeader()->Type, data); - return flag && Send(pak); + auto head = std::bit_cast(*src.GetHeader()); + return Send(head) && EncryptionSend(*src.GetRawData()); } std::optional EncryptionRecv() { - bytearray head(Packet::HeaderSize); - if (!Recv(head)) { + Packet::header_bytes headbuf{}; + if (!Recv(headbuf)) { return std::nullopt; } - Packet pak; - pak.SetBuffer(std::move(head)); - bytearray data(pak.GetHeader()->Size); + Header head = std::bit_cast
(headbuf); + bytearray data(head.Size); if (!EncryptionRecv(data)) { return std::nullopt; } - return Packet(pak.GetHeader()->Type, data); + return Packet(head.Type, data); } - std::future ASyncSend(const bytearray& src) { - return std::async(std::launch::async, [&]() { - return this->Send(src); + std::future ASyncSend(bytearray&& src) { + return std::async(std::launch::async, [&, target = std::move(src)]() { + return this->Send(target); }); } - std::future ASyncRecv(bytearray& dest) { - return std::async(std::launch::async, [&]() { + std::future ASyncRecv(byte_ref dest) { + return std::async(std::launch::async, [=, this]() { return this->Recv(dest); }); } - std::future ASyncSend(const Packet& src) { - return std::async(std::launch::async, [&]() { - return this->Send(src); + std::future ASyncSend(Packet&& src) { + return std::async(std::launch::async, [&, target = std::move(src)]() { + return this->Send(target); }); } std::future> ASyncRecv() { @@ -583,20 +584,20 @@ class basic_TCPSocket : public sockbase { }); } - std::future ASyncEncryptionSend(const bytearray& src) { - return std::async(std::launch::async, [&]() { - return this->EncryptionSend(src); + std::future ASyncEncryptionSend(bytearray&& src) { + return std::async(std::launch::async, [&, target = std::move(src)]() { + return this->EncryptionSend(target); }); } - std::future ASyncEncryptionRecv(bytearray& dest) { - return std::async(std::launch::async, [&]() { + std::future ASyncEncryptionRecv(byte_ref dest) { + return std::async(std::launch::async, [=, this]() { return this->EncryptionRecv(dest); }); } - std::future ASyncEncryptionSend(const Packet& src) { - return std::async(std::launch::async, [&]() { - return this->EncryptionSend(src); + std::future ASyncEncryptionSend(Packet&& src) { + return std::async(std::launch::async, [&, target = std::move(src)]() { + return this->EncryptionSend(target); }); } std::future> ASyncEncryptionRecv() { @@ -605,6 +606,28 @@ class basic_TCPSocket : public sockbase { }); } + template + bool _Send(const T& target) requires (memcpyable) { + return RawSend(&target, sizeof(T)); + } + template + bool _Recv(T& target) requires (memcpyable) { + return RawRecv(&target, sizeof(T)); + } + + template + std::future _ASyncSend(const T& target) requires (memcpyable) { + return std::async(std::launch::async, [this, target]() { + return this->_Send(target); + }); + } + template + std::future _ASyncRecv(T& target) requires (memcpyable) { + return std::async(std::launch::async, [this, &target]() { + return this->_Recv(target); + }); + } + AES128 CryptEngine; protected: diff --git a/include/common.h b/include/common.h index 716b8bc..2838ed5 100644 --- a/include/common.h +++ b/include/common.h @@ -23,6 +23,8 @@ namespace SocketDetail { using byte_ref = std::span; using bytearray = std::vector; + template + using cbytearray = std::array; template concept enum32 = std::is_enum_v && (sizeof(T) == sizeof(uint32_t)); @@ -42,7 +44,4 @@ namespace SocketDetail { template concept cross_convertible = to_byteable && from_byteable; - - - }