Hi all,
I hope this isn't too much code to post, but below is an implementation of the RSA algorithm that I've been working on. There are 3 files: "keygen.cpp", "numtheory.cpp" and "numtheory.h". They compile correctly and I think my functions are correct, but something is going seriously wrong: the files encrypt but don't decrypt properly. I would very much appreciate it if anyone could take a look at the code and see where I might have gone wrong. Thanks a lot.
Code:
-------------keygen.cpp----------------------------
#include "numtheory.h"
#include <stdlib.h>
#include <stdio.h>
#include <ctime>
#include <iostream>
#include <fcntl.h>
#include <string>
#include <cstdlib>
using namespace std;
// Forward declarations: the file routines are defined after main(), which
// calls them. Each takes a file name, an RSA exponent and the modulus N.
void encrypt(const char* file, number exponent, number modulus);
void decrypt(const char* file, number exponent, number modulus);
//
// generatePandQ function
// ----------------------
// Picks two distinct primes p and q using pickRandomPrime(minimum, maximum)
// and returns them through the reference parameters.
//
void generatePandQ(number minimum, number maximum, number& p, number& q)
{
    // Maximum number of redraws before giving up on a second random prime.
    const int MAX_TRIES = 1000;

    p = pickRandomPrime(minimum, maximum);
    q = pickRandomPrime(minimum, maximum);

    // A narrow range can send every draw to the same prime (e.g. [100, 102)
    // always yields 101), so an unbounded "redraw while equal" loop can hang.
    for (int tries = 0; q == p && tries < MAX_TRIES; ++tries)
    {
        q = pickRandomPrime(minimum, maximum);
    }

    if (q == p)
    {
        // Fall back to the next prime above p: with a range of size one,
        // pickRandomPrime starts its upward search at the odd number (p+1)|1.
        q = pickRandomPrime(p + 1, p + 2);
    }
}
int main()
{
srand(time(0));
cout << "Enter minimum value for p and q: " << flush;
number minimum; cin >> minimum;
cout << "Enter maximum value for p and q: " << flush;
number maximum; cin >> maximum;
number p, q;
generatePandQ(minimum, maximum, p, q);
number N = p*q;
number c = (p-1)*(q-1);
number e;
for(e = c / 4; gcd(c, e) != 1; ++e);
number d = modinverse(e, c);
cout << "p is " << p << endl;
cout << "q is " << q << endl;
cout << "N is " << N << endl;
cout << "c is " << c << endl;
cout << "e is " << e << endl;
cout << "d is " << d << endl;
char fileEncrypt[50];
char fileDecrypt[50];
cout<<"Enter name of file: " << endl;
cin>>fileEncrypt;
encrypt(fileEncrypt, e, N);
cout << "Enter name of file to decrypt: " << endl;
cin >> fileDecrypt;
decrypt(fileDecrypt, d, N);
return 0;
}
void encrypt(const char file[], number exponent, number modulus)
{
cout << "Encrypting: exponent is " << exponent << ", N is " << modulus << endl;
const int BUF_SIZE = 8192;
int charsRead = 0, i;
char input_char;
char file_out[30];
unsigned char input_buffer[BUF_SIZE];
unsigned int output_buffer[BUF_SIZE];
FILE* input_file;
FILE* output_file;
strcpy(file_out, file);
strcat(file_out, ".encrypted");
input_file = fopen(file, "r");
if (input_file==NULL) {
cout << "Error opening file";
exit(-1);
}
else {
output_file = fopen(file_out, "w");
while(1)
{
charsRead = 0;
// Step 1: Read BUF_SIZE bytes of file into input buffer.
for (i = 0; i <BUF_SIZE; i++)
{
input_char = fgetc(input_file);
if(input_char == EOF)
break;
else {
input_buffer[i] = input_char;
charsRead++;
}
}
if (charsRead == 0)
{
// done with file.
break;
}
// Step 2: Encrypt this block of data and store in output buffer.
for(i = 0; i < charsRead * sizeof(unsigned int); ++i)
{
output_buffer[i] = powermod( input_buffer[i], exponent, modulus );
fputc(output_buffer[i], output_file);
}
}
}
cout << "Done encrypting." << endl;
fclose(input_file);
fclose(output_file);
}
void decrypt(const char file[], number exponent, number modulus)
{
cout << "Decrypting: exponent is " << exponent << ", N is " << modulus << endl;
const int BUF_SIZE = 8192;
int charsRead = 0, i;
char input_char;
char file_out[30];
unsigned int input_buffer[BUF_SIZE];
unsigned char output_buffer[BUF_SIZE];
FILE* input_file;
FILE* output_file;
strcpy(file_out, file);
strcat(file_out, ".decrypted");
input_file = fopen(file, "r");
if (input_file==NULL) {
cout << "Error opening file";
exit(-1);
}
else {
output_file = fopen(file_out, "w");
while(1)
{
charsRead = 0;
// Step 1: Read BUF_SIZE bytes of file into input buffer.
for (i = 0; i <BUF_SIZE * sizeof(number); i++)
{
input_char = fgetc(input_file);
if(input_char == EOF)
break;
else {
input_buffer[i] = input_char;
charsRead++;
}
}
if (charsRead == 0)
{
// done with file.
break;
}
// Step 2: Encrypt this block of data and store in output buffer.
for(i = 0; i < charsRead / sizeof(number); ++i)
{
output_buffer[i] = (unsigned char)powermod( input_buffer[i], exponent, modulus );
fputc(output_buffer[i], output_file);
}
}
}
cout << "Done encrypting." << endl;
fclose(input_file);
fclose(output_file);
}
Code:
---------------numtheory.cpp------------------------
#include "numtheory.h"
#include <assert.h>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>
//
// mod function
// ------------
// Returns x reduced into the range [0, modulus). Unlike the built-in %,
// a negative x still yields a non-negative result. Asserts modulus > 0.
//
number mod(number x, number modulus)
{
    assert(modulus > 0);
    const number remainder = x % modulus;
    return (remainder < 0) ? remainder + modulus : remainder;
}
//
// mulmod helper
// -------------
// Returns (a * b) % m for a, b < m without overflow. It multiplies directly
// when the product fits in unsigned long long; otherwise it falls back to
// double-and-add, which only ever adds two values below m.
//
static unsigned long long mulmod(unsigned long long a, unsigned long long b,
                                 unsigned long long m)
{
    if (a == 0 || b <= ~0ULL / a)
    {
        return (a * b) % m;
    }
    unsigned long long result = 0;
    while (b > 0)
    {
        if (b & 1)
        {
            // result = (result + a) % m, without computing result + a.
            result = (result >= m - a) ? result - (m - a) : result + a;
        }
        // a = (2 * a) % m, without computing 2 * a.
        a = (a >= m - a) ? a - (m - a) : a + a;
        b >>= 1;
    }
    return result;
}

//
// powermod function
// -----------------
// Returns value^power mod modulus, in [0, modulus), by right-to-left binary
// exponentiation (square-and-multiply).
//
// value is first reduced with mod(), so negative or out-of-range values are
// handled. Every product goes through mulmod(), so no intermediate result
// overflows for any positive `number` modulus (the previous version
// overflowed once modulus^2 exceeded the range of `number`).
// Requires modulus > 0 and power >= 0.
//
number powermod(number value, number power, number modulus)
{
    assert(modulus > 0);
    assert(power >= 0);
    const unsigned long long m = static_cast<unsigned long long>(modulus);
    unsigned long long base = static_cast<unsigned long long>(mod(value, modulus));
    // 1 % m rather than 1, so that modulus == 1 correctly yields 0.
    unsigned long long result = 1 % m;
    for (number bits = power; bits > 0; bits >>= 1)
    {
        if (bits & 1)
        {
            result = mulmod(result, base, m);
        }
        base = mulmod(base, base, m);
    }
    return static_cast<number>(result);
}
//
// gcd function
// ------------
// Greatest common divisor of two numbers by Euclid's algorithm, written
// recursively: the larger argument is placed first, then
// gcd(a, b) = gcd(b, a % b) until b is no longer positive.
//
number gcd(number high, number low)
{
    if (high < low)
    {
        return gcd(low, high);
    }
    return (low > 0) ? gcd(low, high % low) : high;
}
//
// gcd_combination function
// ------------------------
// Computes the gcd of high and low with the Euclidean algorithm and also
// returns, through s and t, the coefficients of a linear combination of
// high and low. Intended postcondition: s*high + t*low == gcd(high, low).
// modinverse() depends on it, since it returns t mod modulus. The
// early-return branch below satisfies it directly, and hand traces of the
// iterative branch agree (e.g. (7, 3) -> s = 1, t = -2; (13, 5) -> s = 2,
// t = -5). NOTE(review): not proven here; worth a unit test.
//
// Preconditions (asserted): high > low, and both are positive.
//
number gcd_combination(number high, number low, number& s, number& t)
{
//
// We require that high is greater than low.
//
assert(high > low);
assert(high > 0 && low > 0);
//
// If low divides high, our algorithm won't return the
// correct s and t values. So, we handle this case
// right now.
// (s = 1, t = -(high/low - 1) gives high - (high/low - 1)*low == low.)
//
if (high % low == 0)
{
s = 1;
t = -1 * ((high / low) - 1);
return low;
}
//
// We need to seed the s and t values with their correct
// starting values.
// (These are the coefficients for high itself: 1*high + 0*low. The
// j == 1 iteration then seeds the coefficients for low: 0*high + 1*low.)
//
s = 1;
t = 0;
//
// We store the last two t's and s's that were calculated
// for use in computing future t's and s's.
//
// NOTE(review): old_t, old_s and q are not initialized, so the shifts
// at the top of the first iterations copy indeterminate values.
// Tracing suggests those values never reach s or t, but zero-initializing
// the arrays would avoid reading uninitialized memory.
//
number old_t[2];
number old_s[2];
//
// The q array is where we store our quotients.
// q[2] is the current quotient, q[1] is one back from
// that, and q[0] is one back from q[1].
//
number q[3];
for(number j = 0; (true); ++j)
{
// Shift the quotient and coefficient histories forward one iteration.
q[0] = q[1];
q[1] = q[2];
old_s[0] = old_s[1];
old_s[1] = s;
old_t[0] = old_t[1];
old_t[1] = t;
if (j == 1)
{
s = 0;
t = 1;
}
else if (j > 1)
{
// Extended-Euclid recurrence: value from two steps back minus
// the older quotient times the value from one step back.
s = old_s[0] - (q[0] * old_s[1]);
t = old_t[0] - (q[0] * old_t[1]);
}
// Euclid step: high = q[2] * low + remainder.
q[2] = number(high / low);
number remainder(high - (low * q[2]));
if (remainder == 0)
{
// low divides high, so low is the gcd. Shift once more and apply
// the recurrence a final time so that s and t describe it.
old_s[0] = old_s[1];
old_s[1] = s;
old_t[0] = old_t[1];
old_t[1] = t;
q[0] = q[1];
q[1] = q[2];
s = old_s[0] - (q[0] * old_s[1]);
t = old_t[0] - (q[0] * old_t[1]);
return low;
}
// Continue Euclid on the pair (low, remainder).
high = low;
low = remainder;
}
}
//
// modinverse function
// -------------------
// Returns the inverse of x modulo `modulus`, as a value in [0, modulus).
// x is first reduced into [0, modulus); the Bezout coefficient of x from
// gcd_combination is then normalized with mod(). Asserts that x and
// modulus are coprime.
//
number modinverse(number x, number modulus)
{
    number s;
    number t;
    const number g = gcd_combination(modulus, mod(x, modulus), s, t);
    assert(g == 1);
    (void)g;  // only read by the assert; silences NDEBUG builds
    return mod(t, modulus);
}
number pickRandomPrime(number low, number high)
{
// Pick a random odd number in our interval.
number x = (low + (rand() % (high - low))) | 1;
while( (powermod(2, x-1, x) != 1) ||
(powermod(3, x-1, x) != 1) ||
(powermod(5, x-1, x) != 1) ||
(powermod(7, x-1, x) != 1) ||
(powermod(11, x-1, x) != 1) )
{
// If any of those conditions are true, we have
// found a composite number. Try the next one.
x += 2;
}
// once we get here, we have found a prime.
return x;
}
//
// factor_trialdiv function
// ------------------------
//
// Attempts to find a factor of n using trial division by 2, 3, 4, ...
//
// Returns the smallest factor >= 2 that divides n, or 1 if none was found.
// 1 is returned for prime n, and also for every n < 4 (including 0, 1 and
// negative n), so callers must not read 1 as "n is prime" for those.
//
number factor_trialdiv(number n)
{
    // `i <= n / i` is the integer form of `i <= sqrt(n)`. It avoids
    // recomputing a floating-point sqrt on every iteration. It also cannot be
    // thrown off by double rounding when n exceeds double's 53-bit mantissa,
    // which could otherwise skip the factor of a large perfect square.
    for (number i = 2; i <= n / i; ++i)
    {
        if (n % i == 0)
        {
            return i;
        }
    }
    return 1;
}
Code:
-------------------------numtheory.h--------------------
#ifndef INCLUDED_NUMBER_THEORY_H
#define INCLUDED_NUMBER_THEORY_H
// Integer type used for all key material and arithmetic. Its width is
// platform-dependent (`long` is only 32 bits on some platforms, e.g. Windows),
// and values such as N = p*q must fit in it, which bounds the usable key size.
typedef long number;
// x reduced into [0, modulus); asserts modulus > 0.
number mod(number x, number modulus);
// value raised to power, reduced modulo modulus.
number powermod(number value, number power, number modulus);
// Greatest common divisor, computed with Euclid's algorithm.
number gcd(number high, number low);
// gcd(high, low), plus coefficients s and t of a linear combination of high
// and low; asserts high > low > 0.
number gcd_combination(number high, number low, number& s, number& t);
// Inverse of x modulo modulus; asserts gcd(x mod modulus, modulus) == 1.
number modinverse(number x, number modulus);
// Searches upward for a (probable) prime, starting from a random point in
// [low, high) rounded up to odd; the result may exceed high.
number pickRandomPrime(number low, number high);
// Smallest factor >= 2 found by trial division, or 1 if none was found.
number factor_trialdiv(number n);
#endif // INCLUDED_NUMBER_THEORY_H
---------------------END-----------------------------