#include<windows.h>
#include<iostream>
#include<math.h>
#include<stdio.h>
using namespace std;
typedef int Elemtype;
Elemtype p, q, e;
Elemtype fn;
Elemtype m, c;
int flag = 0;
typedef void(*Msghandler)(void);
struct MsgMap
{
char ch;
Msghandler handler;
};
/*公钥*/
struct PU
{
Elemtype e;
Elemtype n;
} pu;
/*私钥*/
struct PR
{
Elemtype d;
Elemtype n;
} pr;
/*判定一个数是否为素数*/
bool test_prime(Elemtype m)
{
if (m <= 1)
{
return false;
}
else if (m == 2)
{
return true;
}
else
{
for (int i = 2; i <= sqrt(m); i++)
{
if ((m % i) == 0)
{
return false;
break;
}
}
return true;
}
}
/*将十进制数据转化为二进制数组*/
void switch_to_bit(Elemtype b, Elemtype bin[32])
{
int n = 0;
while (b > 0)
{
bin[n] = b % 2;
n++;
b /= 2;
}
}
/*初始化主界面*/
void Init(){
cout<<"******************************************"<<endl;
cout<<"*** Welcome to use RSA encoder ***"<<endl;
cout<<"*** 1.setkey ***"<<endl;
cout<<"*** 2.加密 ***"<<endl;
cout<<"*** 3.解密 ***"<<endl;
cout<<"*** 4.退出 ***"<<endl;
cout<<"******************************************"<<endl;
cout<<"press a key:"<<endl;
}
/*将两个数排序,大的在前面*/
void order(Elemtype &in1, Elemtype &in2)
{
Elemtype a = (in1 > in2 ? in1 : in2);
Elemtype b = (in1 < in2 ? in1 : in2);
in1 = a;
in2 = b;
}
/*求最大公约数*/
Elemtype gcd(Elemtype a, Elemtype b)
{
order(a, b);
int r;
if (b == 0)
{
return a;
}
else
{
while (true)
{
r = a % b;
a = b;
b = r;
if (b == 0)
{
return a;
break;
}
}
}
}
/*用扩展的欧几里得算法求乘法逆元*/
Elemtype extend_euclid(Elemtype m, Elemtype bin)
{
order(m, bin);
Elemtype a[3], b[3], t[3];
a[0] = 1, a[1] = 0, a[2] = m;
b[0] = 0, b[1] = 1, b[2] = bin;
if (b[2] == 0)
{
return a[2] = gcd(m, bin);
}
if (b[2] == 1)
{
return b[2] = gcd(m, bin);
}
while (true)
{
if (b[2] == 1)
{
return b[1];
break;
}
int q = a[2] / b[2];
for (int i = 0; i < 3; i++)
{
t[i] = a[i] - q * b[i];
a[i] = b[i];
b[i] = t[i];
}
}
}
/*快速模幂算法*/
Elemtype modular_multiplication(Elemtype a, Elemtype b, Elemtype n)
{
Elemtype f = 1;
Elemtype bin[32];
switch_to_bit(b, bin);
for (int i = 32; i >= 0; i--)
{
f = (f * f) % n;
if (bin[i] == 1)
{
f = (f * a) % n;
}
}
return f;
}
/*产生密钥*/
void produce_key()
{
cout << "输入素数 p 和 q:";
cin >> p >> q;
while (!(test_prime(p) && test_prime(q)))
{
cout << "输入错误,请重新输入!" << endl;
cout << "输入素数 p 和 q:";
cin >> p >> q;
};
pr.n = p * q;
pu.n = p * q;
fn = (p - 1) * (q - 1);
cout << "fn为:" << fn << endl;
cout << "输入随机数e:";
cin >> e;
while ((gcd(fn, e) != 1))
{
cout << "e输入错误,请重新输入!" << endl;
cout << "输入随机数e:";
cin >> e;
}
pr.d = (extend_euclid(fn, e) + fn) % fn;
pu.e = e;
flag = 1;
cout << "公钥(e,n):" << pu.e << "," << pu.n << endl;
cout << "私钥d:" << pr.d << endl;
cout << "请输入下一步操作序号:" << endl;
}
/*加密*/
void encrypt()
{
if (flag == 0)
{
cout << "setkey first:" << endl;
produce_key();
}
cout << "输入明文 m:";
cin >> m;
c = modular_multiplication(m, pu.e, pu.n);
cout << "密文c 为:" << c << endl;
cout << "请输入下一步操作序号:" << endl;
}
/*解密*/
void decrypt()
{
if (flag == 0)
{
cout << "setkey first:" << endl;
produce_key();
}
cout << "输入密文 c:";
cin >> c;
m = modular_multiplication(c, pr.d, pr.n);
cout << "明文m 为:" << m << endl;
cout << "请输入下一步操作序号:" << endl;
}
/*消息映射*/
MsgMap Messagemap[] =
{
{ '1', produce_key },
{ '3', decrypt },
{ '2', encrypt },
{ '4', NULL }
};
/*主函数,提供循环*/
int main()
{
Init();
char d;
while ((d = getchar()) != '4')
{
int i = 0;
while (Messagemap[i].ch)
{
if (Messagemap[i].ch == d)
{
Messagemap[i].handler();
break;
}
i++;
}
}
}
使输出明文和输入的明文一致