From ab97ac5365d6a26070f6d7091651a385adae453e Mon Sep 17 00:00:00 2001 From: xueyumusic <278006819@qq.com> Date: Fri, 31 Jan 2020 19:00:13 +0800 Subject: [PATCH] homo aes compatible --- misc/aes/CMakeLists.txt | 12 ++++++++++++ misc/aes/Test_AES.cpp | 28 +++++++++------------------- misc/aes/homAES.cpp | 33 +++++++++++++++++++++------------ misc/aes/homAES.h | 4 ++-- misc/aes/simpleAES.cpp | 1 + 5 files changed, 45 insertions(+), 33 deletions(-) diff --git a/misc/aes/CMakeLists.txt b/misc/aes/CMakeLists.txt index abb087e28..3c5b4a36a 100644 --- a/misc/aes/CMakeLists.txt +++ b/misc/aes/CMakeLists.txt @@ -1,2 +1,14 @@ +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) +## Use -std=c++14 as default. +set(CMAKE_CXX_STANDARD 14) +## Disable C++ extensions +set(CMAKE_CXX_EXTENSIONS OFF) +## Require full C++ standard +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +project(Test_AES_example + LANGUAGES CXX) + +find_package(helib 1.0.0 EXACT REQUIRED) add_executable(TEST_AES simpleAES.cpp homAES.cpp Test_AES.cpp) target_link_libraries(TEST_AES PUBLIC helib) diff --git a/misc/aes/Test_AES.cpp b/misc/aes/Test_AES.cpp index a596379cb..7c41d2d3b 100644 --- a/misc/aes/Test_AES.cpp +++ b/misc/aes/Test_AES.cpp @@ -7,9 +7,11 @@ namespace std {} using namespace std; namespace NTL {} using namespace NTL; +namespace helib{} using namespace helib; #include #include "homAES.h" -#include "Ctxt.h" +#include "helib/Ctxt.h" +#include "helib/ArgMap.h" static long mValues[][14] = { //{ p, phi(m), m, d, m1, m2, m3, g1, g2, g3,ord1,ord2,ord3, c_m} @@ -39,7 +41,7 @@ extern void Cipher(unsigned char out[16], int main(int argc, char **argv) { - ArgMapping amap; + ArgMap amap; long idx = 0; amap.arg("sz", idx, "parameter-sets: toy=0 through huge=5"); @@ -47,8 +49,8 @@ int main(int argc, char **argv) long c=3; amap.arg("c", c, "number of columns in the key-switching matrices"); - long L=0; - amap.arg("L", L, "# of levels in the modulus chain", "heuristic"); + long N=0; + amap.arg("N", N, "# of bits of the modulus chain"); long B=23; amap.arg("B", B, "# of bits per level (only 64-bit machines)"); @@ -66,17 +68,6 @@ int main(int argc, char **argv) vector gens; vector ords; - if (boot) { - if (L<23) L=23; - if (idx<1) idx=1; // the sz=0 params are incompatible with bootstrapping - } else { -#if (NTL_SP_NBITS<50) - if (L<46) L=46; -#else - if (L<42) L=42; -#endif - } - long p = mValues[idx][0]; // long phim = mValues[idx][1]; long m = mValues[idx][2]; @@ -94,8 +85,7 @@ int main(int argc, char **argv) if (abs(mValues[idx][12])>1) ords.push_back(mValues[idx][12]); cout << "*** Test_AES: c=" << c - << ", L=" << L - << ", B=" << B + << ", N=" << N << ", boot=" << boot << ", packed=" << packed << ", m=" << m @@ -107,10 +97,10 @@ int main(int argc, char **argv) cout << "computing key-independent tables..." << std::flush; Context context(m, p, /*r=*/1, gens, ords); #if (NTL_SP_NBITS>=50) // 64-bit machines - context.bitsPerLevel = B; + //context.bitsPerLevel = B; #endif context.zMStar.set_cM(mValues[idx][13]/100.0); // the ring constant - buildModChain(context, L, c); + buildModChain(context, N, c); if (boot) context.makeBootstrappable(mvec); tm += GetTime(); diff --git a/misc/aes/homAES.cpp b/misc/aes/homAES.cpp index 4415fbba8..8c2af6cb5 100644 --- a/misc/aes/homAES.cpp +++ b/misc/aes/homAES.cpp @@ -2,6 +2,7 @@ */ namespace std {} using namespace std; namespace NTL {} using namespace NTL; +namespace helib {} using namespace helib; #include #include "homAES.h" @@ -92,6 +93,8 @@ static void packCtxt(vector& to, const vector& from, static void unackCtxt(vector& to, const vector& from, const Mat& unpackConsts); +static long findBaseLevel(const Ctxt& c); + // Implementation of the class HomAES static const uint8_t aesPolyBytes[] = { 0x1B, 0x1 }; // X^8+X^4+X^3+X+1 @@ -99,12 +102,12 @@ const GF2X HomAES::aesPoly = GF2XFromBytes(aesPolyBytes, 2); HomAES::HomAES(const Context& context): ea2(context,aesPoly,context.alMod) #ifndef USE_ZZX_POLY // initialize DoubleCRT using the context -, affVec(context) +, affVec(context,context.allPrimes()) #endif { // Sanity-check: we need the first dimension to be divisible by 16. //OLD: assert( context.zMStar.OrderOf(0) % 16 == 0 ); - helib::assertEq(context.zMStar.OrderOf(0) % 16, 0l); + helib::assertEq(context.zMStar.OrderOf(0) % 16, 0l, "The first dimension need to be divisible by 16"); // Compute the GF2-affine transformation constants buildAffineEnc(encAffMat, affVec, ea2); @@ -171,7 +174,7 @@ void HomAES::setPackingConstants() long e = ea.getDegree() / 8; // the extension degree //OLD: assert(ea.getDegree()==e*8 && e<=(long) sizeof(long)); - helib::assertEq(ea.getDegree()==e*8, "ea must have degree divisible by 8"); + helib::assertEq(ea.getDegree(), e*8, "ea must have degree divisible by 8"); helib::assertTrue(e<=(long) sizeof(long), "extension degree must be at most 8 times sizeof(long)"); GF2EBak bak; bak.save(); // save current modulus (if any) @@ -218,14 +221,14 @@ void HomAES::homAESenc(vector& eData, const vector& aesKey) const for (long i=1; i<(long)aesKey.size(); i++) { // apply the AES rounds // ByteSub - if (eData[0].findBaseLevel() < 4) batchRecrypt(eData); + if (findBaseLevel(eData[0]) < 4) batchRecrypt(eData); invert(eData); // apply Z -> Z^{-1} to all elements of eData #ifdef DEBUG_PRINTOUT CheckCtxt(eData[0], "+ After invert"); // cerr << " + After invert "; // decryptAndPrint(cerr, eData[0], *dbgKey, *dbgEa); #endif - if (eData[0].findBaseLevel() < 2) batchRecrypt(eData); + if (findBaseLevel(eData[0]) < 2) batchRecrypt(eData); for (long j=0; j<(long)eData.size(); j++) { // GF2 affine transformation applyLinPolyLL(eData[j], encAffMat, ea2.getDegree()); eData[j].addConstant(affVec); @@ -237,7 +240,7 @@ void HomAES::homAESenc(vector& eData, const vector& aesKey) const #endif // Apply RowShift/ColMix to each ciphertext - if (eData[0].findBaseLevel() < 2) batchRecrypt(eData); + if (findBaseLevel(eData[0]) < 2) batchRecrypt(eData); if (i<(long)aesKey.size()-1) { for (long j=0; j<(long)eData.size(); j++) encRowColTran(eData[j], encLinTran, ea2); @@ -295,7 +298,7 @@ void HomAES::homAESdec(vector& eData, const vector& aesKey) const for (long j=0; j<(long)eData.size(); j++) eData[j] -= aesKey[i]; // Apply RowShift/ColMix to each ciphertext - if (eData[0].findBaseLevel() < 2) batchRecrypt(eData); + if (findBaseLevel(eData[0]) < 2) batchRecrypt(eData); // if (eData[0].log_of_ratio() > (-lvlBits)) batchRecrypt(eData); if (i<(long)aesKey.size()-1) for (long j=0; j<(long)eData.size(); j++) @@ -311,7 +314,7 @@ void HomAES::homAESdec(vector& eData, const vector& aesKey) const #endif // ByteSub - if (eData[0].findBaseLevel() < 2) batchRecrypt(eData); + if (findBaseLevel(eData[0]) < 2) batchRecrypt(eData); for (long j=0; j<(long)eData.size(); j++) { // GF2 affine transformation eData[j].addConstant(affVec); applyLinPolyLL(eData[j], decAffMat, ea2.getDegree()); @@ -321,7 +324,7 @@ void HomAES::homAESdec(vector& eData, const vector& aesKey) const // cerr << " + After affine "; // decryptAndPrint(cerr, eData[0], *dbgKey, *dbgEa); #endif - if (eData[0].findBaseLevel() < 4) batchRecrypt(eData); + if (findBaseLevel(eData[0]) < 4) batchRecrypt(eData); invert(eData); // apply Z -> Z^{-1} to all elements of eData #ifdef DEBUG_PRINTOUT CheckCtxt(eData[0], "+ After invert"); @@ -440,7 +443,7 @@ static void buildAffine(vector& binMat, PolyType* binVec, ea2.encode(zzxMat[j], scratch); // encode these slots } #ifndef USE_ZZX_POLY - binMat.resize(8,DoubleCRT(ea2.getContext())); + binMat.resize(8,DoubleCRT(ea2.getContext(),ea2.getContext().allPrimes())); for (long j=0; j<8; j++) binMat[j] = zzxMat[j]; // convert to DoubleCRT #endif @@ -478,7 +481,7 @@ static void buildLinEnc(vector& encLinTran, #ifdef USE_ZZX_POLY encLinTran.resize(6); #else - encLinTran.resize(6,DoubleCRT(ea2.getContext())); + encLinTran.resize(6,DoubleCRT(ea2.getContext(),ea2.getContext().allPrimes())); #endif for (long i=0; i<3; i++) { // constants for the RowShift/ColMix trans for (long j=0; j& decLinTran, #ifdef USE_ZZX_POLY decLinTran.resize(8); #else - decLinTran.resize(8,DoubleCRT(ea2.getContext())); + decLinTran.resize(8,DoubleCRT(ea2.getContext(),ea2.getContext().allPrimes())); #endif for (long i=0; i<4; i++) { // constants for the RowShift/ColMix trans for (long j=0; j& to, const vector& from, } } } + +// A hack to get this to compile for now +static long findBaseLevel(const Ctxt& c) +{ + return long(c.naturalSize() / 23); // FIXME: replace 23 by something else +} diff --git a/misc/aes/homAES.h b/misc/aes/homAES.h index 07ca36efa..92568c6c7 100644 --- a/misc/aes/homAES.h +++ b/misc/aes/homAES.h @@ -3,8 +3,8 @@ #include #include #include -#include "EncryptedArray.h" -#include "hypercube.h" +#include "helib/EncryptedArray.h" +#include "helib/hypercube.h" #ifdef USE_ZZX_POLY #define PolyType ZZX diff --git a/misc/aes/simpleAES.cpp b/misc/aes/simpleAES.cpp index 9639f7885..16de0123c 100644 --- a/misc/aes/simpleAES.cpp +++ b/misc/aes/simpleAES.cpp @@ -33,6 +33,7 @@ Find the Wikipedia page of AES at: // Used for giving output to the screen. #include #include +#include // The number of columns comprising a state in AES. This is a constant in AES. // Value=4