Code:
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <errno.h>
#include <stdio.h>
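/*
* Note: a natural number is stored as an array of `size` 32-bit words, least
* significant word first. number[size] is an extra overflow word that must be
* allocated and must be zero.
*/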
/* Duplicate or create a natural number.
 *
 * Allocates new_size + 1 words; the extra word is the zero overflow word.
 * The low min(old_size, new_size) words of old_number are copied. Every
 * other word, including the overflow word, is zeroed. If new_size < old_size
 * the value is truncated to its low new_size words.
 * old_number may be NULL when old_size is zero.
 *
 * Returns the new array (the caller frees it), or NULL with errno = ENOMEM.
 */
uint32_t *duplicate(const uint32_t *const old_number, const size_t old_size, const size_t new_size)
{
    /* Never copy more words than the destination holds (avoids heap overflow
     * and the size_t wrap-around in the memset length when shrinking). */
    const size_t copy = (old_size < new_size) ? old_size : new_size;
    uint32_t *new_number;

    /* Guard (1 + new_size) * sizeof (uint32_t) against wrap-around. */
    if (new_size >= SIZE_MAX / sizeof (uint32_t)) {
        errno = ENOMEM;
        return NULL;
    }
    new_number = malloc((1 + new_size) * sizeof (uint32_t));
    if (!new_number) {
        errno = ENOMEM;
        return NULL;
    }
    if (copy > 0)
        memcpy(new_number, old_number, copy * sizeof (uint32_t));
    memset(new_number + copy, 0, (1 + new_size - copy) * sizeof (uint32_t));
    return new_number;
}
/* Count the significant words of a natural number: the index of the highest
 * nonzero word plus one, or 0 when every word (and thus the value) is zero.
 */
static inline size_t words(const uint32_t *const number, const size_t size)
{
    size_t count = size;

    /* Trim zero words off the most significant end. */
    while (count > 0 && number[count - 1] == 0U)
        count--;
    return count;
}
/* Compare two natural numbers, possibly stored with different word counts.
 * Like strcmp(), but returns only -1 (number1 < number2), 0 (equal) or
 * +1 (number1 > number2). Zero high words do not affect the result.
 */
int compare(const uint32_t *const number1, const size_t size1,
            const uint32_t *const number2, const size_t size2)
{
    size_t i;

    /* Any nonzero word above the shorter operand decides the result. */
    for (i = size1; i > size2; i--)
        if (number1[i - 1])
            return +1;
    for (i = size2; i > size1; i--)
        if (number2[i - 1])
            return -1;

    /* Compare the common words, most significant first. Starting at the
     * shorter size explicitly (instead of reusing the loop counter) ensures the
     * top common word is examined and no index wraps when a size is 0. */
    i = (size1 < size2) ? size1 : size2;
    while (i-- > 0) {
        if (number1[i] != number2[i])
            return (number1[i] > number2[i]) ? +1 : -1;
    }
    return 0;
}
/* Store floor((number1 + number2) / 2) in result. All three operands have
 * `size` words. result may be the same array as number1 or number2, because
 * each input word is read before the matching result word is written.
 * Returns the bit shifted out of the bottom (the ".5" part) as the top bit
 * of a 32-bit word: 0x80000000 if the sum was odd, 0 otherwise.
 */
uint32_t average(uint32_t *const result,
                 const uint32_t *const number1,
                 const uint32_t *const number2,
                 const size_t size)
{
    uint32_t spill = 0U;
    size_t k;

    /* Pass 1: result = number1 + number2; spill holds the final carry-out. */
    for (k = 0; k < size; k++) {
        const uint64_t sum = (uint64_t)number1[k] + (uint64_t)number2[k] + (uint64_t)spill;
        result[k] = (uint32_t)sum;
        spill = (uint32_t)(sum >> 32);
    }

    /* Pass 2: shift the (size + 1)-word sum right by one bit, top down. */
    for (k = size; k > 0; k--) {
        const uint32_t word = result[k - 1];
        result[k - 1] = (uint32_t)(spill << 31) | (word >> 1);
        spill = word & 1U;
    }

    return (uint32_t)(spill << 31);
}
/* Compute result = number1 - number2 (mod 2^(32*size)); all have `size` words.
 * result may be the same array as number1, as number2, or neither, because
 * each input word is read before the matching result word is written.
 * Returns 0 if number1 >= number2, or -1 if the subtraction borrowed
 * (number1 < number2, so result holds the two's-complement wrap-around).
 */
int32_t subtract(uint32_t *const result,
                 const uint32_t *const number1,
                 const uint32_t *const number2,
                 const size_t size)
{
    uint32_t borrow = 0U;
    size_t i;

    for (i = 0; i < size; i++) {
        /* Unsigned arithmetic wraps modulo 2^64. The true difference lies in
         * [-2^32, 2^32 - 1], so bit 63 is set exactly when it went negative.
         * This avoids right-shifting a negative signed value, which is
         * implementation-defined. */
        const uint64_t diff = (uint64_t)number1[i] - (uint64_t)number2[i] - (uint64_t)borrow;
        result[i] = (uint32_t)diff;
        borrow = (uint32_t)(diff >> 63);
    }
    return borrow ? -1 : 0;
}
/* Multiply two natural numbers: result = number1 * number2.
 * The operands may have different word counts (size1 and size2).
 * result must hold size1 + size2 + 1 words and must not overlap either
 * operand. It is cleared first, and result[size1 + size2] is the overflow word.
 * Returns result[size1 + size2]. The full product always fits in size1 + size2
 * words, so this is 0.
 */
uint32_t multiply(uint32_t *const result,
                  const uint32_t *const number1, const size_t size1,
                  const uint32_t *const number2, const size_t size2)
{
    size_t i1, i2;

    memset(result, 0, (size1 + size2 + 1) * sizeof (uint32_t));
    for (i1 = 0; i1 < size1; i1++) {
        uint64_t carry = 0U;
        for (i2 = 0; i2 < size2; i2++) {
            /* Max value: (2^32-1) + (2^32-1)^2 + (2^32-1) = 2^64 - 1, so the
             * accumulated word plus carry never overflows 64 bits. */
            const uint64_t temp = (uint64_t)result[i1 + i2]
                                + (uint64_t)number1[i1] * (uint64_t)number2[i2]
                                + carry;
            result[i1 + i2] = (uint32_t)temp;
            carry = temp >> 32U;
        }
        /* This word has not been touched by earlier rows, so store the row's
         * carry directly. The original rippled every carry to the top word
         * instead, which was O(n^3). */
        result[i1 + size2] = (uint32_t)carry;
    }
    return result[size1 + size2];
}
/* In place: number = number * scale + add, over `size` words.
 * Returns the 32-bit word that did not fit, i.e. the carry out of the top.
 */
uint32_t small_mul_add(uint32_t *const number, const size_t size, const uint32_t scale, const uint32_t add)
{
    /* Seeding the carry with the addend folds the addition into the first
     * multiply step. (2^32-1)^2 + (2^32-1) still fits in 64 bits. */
    uint32_t carry = add;

    for (size_t k = 0; k != size; ++k) {
        const uint64_t product = (uint64_t)scale * (uint64_t)number[k] + (uint64_t)carry;
        number[k] = (uint32_t)(product & 0xFFFFFFFFU);
        carry = (uint32_t)(product >> 32);
    }
    return carry;
}
/* In place: number = number / divisor (integer quotient) over `size` words.
 * Returns the remainder.
 * A zero divisor would be undefined behaviour (division by zero). In that
 * case number is left untouched, errno is set to EDOM, and 0 is returned.
 */
uint32_t small_div(uint32_t *const number, const size_t size, const uint32_t divisor)
{
    uint64_t remainder = 0U;
    size_t i = size;

    if (divisor == 0U) {
        errno = EDOM;
        return 0U;
    }

    /* Zero high words contribute nothing; skip them. */
    while (i > 0 && !number[i - 1])
        i--;

    /* Long division, most significant word first. Because remainder < divisor,
     * each partial quotient fits in 32 bits. */
    while (i-- > 0) {
        const uint64_t temp = (remainder << 32U) | (uint64_t)number[i];
        number[i] = (uint32_t)(temp / (uint64_t)divisor);
        remainder = temp % (uint64_t)divisor;
    }
    return (uint32_t)remainder;
}
/* Skip leading white-space (tab, newline, vertical tab, form feed, carriage
 * return, space). Returns a pointer to the first other character, or NULL
 * when string is NULL or holds nothing but white-space.
 */
static inline const char *nonspace(const char *string)
{
    if (string == NULL)
        return NULL;

    for (;; string++) {
        switch (*string) {
        case '\t':
        case '\n':
        case '\v':
        case '\f':
        case '\r':
        case ' ':
            continue;           /* white-space: advance */
        case '\0':
            return NULL;        /* ran out of characters */
        default:
            return string;
        }
    }
}
/* Parse a decimal string into a natural number.
 *
 * Leading white-space is skipped, then a run of decimal digits is consumed.
 * On success *numberptr receives a newly allocated array (the caller frees it)
 * of *sizeptr significant words plus a zero overflow word, and the return
 * value points just past the last digit.
 * On failure NULL is returned and errno is set:
 *   EINVAL - numberptr or sizeptr is NULL;
 *   ENOENT - no digit at the start of the (space-trimmed) string;
 *   ENOMEM - out of memory.
 */
static const char *string_to_natural(const char *string, uint32_t **numberptr, size_t *sizeptr)
{
    uint32_t *number = NULL;
    size_t size = 0;   /* words allocated, excluding the overflow word */
    size_t used = 0;   /* significant words */

    if (!numberptr || !sizeptr) {
        errno = EINVAL;
        return NULL;
    }

    string = nonspace(string);
    if (!string || !(*string >= '0' && *string <= '9')) {
        errno = ENOENT;
        return NULL;
    }

    while (*string >= '0' && *string <= '9') {
        uint32_t carry = (uint32_t)(*(string++) - '0');

        /* Ensure number[used] exists in case the multiply carries out. */
        if (used >= size) {
            const size_t new_size = used + 16;
            uint32_t *new_number = realloc(number, (1 + new_size) * sizeof (uint32_t));
            if (!new_number) {
                free(number);
                errno = ENOMEM;
                return NULL;
            }
            number = new_number;
            size = new_size;
        }
        carry = small_mul_add(number, used, 10U, carry);
        if (carry)
            number[used++] = carry;
    }

    /* Shrink to fit. A failed shrink leaves the original (larger) block
     * valid, so keep it rather than discarding a correct result. */
    {
        uint32_t *new_number = realloc(number, (1 + used) * sizeof (uint32_t));
        if (new_number)
            number = new_number;
    }
    number[used] = 0U;

    *numberptr = number;
    *sizeptr = used;
    return string;
}
/* Render a natural number of `size` words as a decimal string, with no leading
 * zeros ("0" for the value zero). The string is allocated with malloc() and
 * must be freed by the caller. Returns NULL with errno = ENOMEM on failure.
 */
char *natural_to_string(const uint32_t *const number, const size_t size)
{
    /* A 32-bit word needs at most ~9.64 decimal digits, and digits are
     * produced nine at a time, so 10 per word plus 10 spare is enough. */
    const size_t capacity = 10 * size + 10;
    char *const text = malloc(capacity + 1);
    uint32_t *const scratch = duplicate(number, size, size);
    char *cursor;
    size_t length;

    if (text == NULL || scratch == NULL) {
        free(scratch);
        free(text);
        errno = ENOMEM;
        return NULL;
    }

    /* Fill the buffer from the end: each pass peels off one base-10^9 limb
     * and writes its nine digits (zero-padded) right to left. */
    cursor = text + capacity;
    *cursor = '\0';
    do {
        uint32_t chunk = small_div(scratch, size, 1000000000U);
        int k;
        for (k = 0; k < 9; k++) {
            *--cursor = (char)('0' + chunk % 10U);
            chunk /= 10U;
        }
    } while (words(scratch, size) != 0);

    /* Drop the zero padding, but keep a single '0' for the value zero. */
    while (*cursor == '0')
        cursor++;
    if (*cursor == '\0')
        cursor--;

    /* Slide the digits (and terminator) to the front of the buffer. */
    length = (size_t)(text + capacity - cursor);
    if (cursor != text)
        memmove(text, cursor, length + 1);

    free(scratch);
    return text;
}
/* Compute result = number mod modulus.
 *
 * modulus and result have `size` words. number must provide at least
 * 2*size words and must satisfy number < modulus*modulus, so that the
 * quotient fits in `size` words.
 *
 * Returns 0 on success. On failure returns the error code and also stores
 * it in errno: ENOMEM if allocation fails, EDOM if the precondition does not
 * hold (this also covers size == 0).
 */
int modulo(uint32_t *const result,
           const uint32_t *const number,
           const uint32_t *const modulus,
           const size_t size)
{
    uint32_t *product = duplicate(NULL, 0, 2 * size);
    uint32_t *factor = duplicate(NULL, 0, size);
    uint32_t *minimum = duplicate(NULL, 0, size);
    uint32_t *maximum = duplicate(modulus, size, size);
    uint32_t *temp;
    int status = 0;

    if (!product || !factor || !minimum || !maximum) {
        status = ENOMEM;
        goto out;
    }

    /* Precondition: modulus^2 > number. */
    multiply(product, modulus, size, modulus, size);
    if (compare(product, 2 * size, number, 2 * size) <= 0) {
        status = EDOM;
        goto out;
    }

    /* Binary search for the quotient q = floor(number / modulus).
     * Invariant: minimum*modulus <= number < maximum*modulus.
     * An exact multiple (product == number) must move `minimum`. Otherwise
     * the search settles one below the true quotient and yields `modulus`
     * instead of 0. */
    while (1) {
        average(factor, minimum, maximum, size);
        multiply(product, factor, size, modulus, size);
        if (!compare(factor, size, minimum, size))
            break;   /* maximum == minimum + 1; product == minimum*modulus */
        if (compare(product, 2 * size, number, 2 * size) <= 0) {
            temp = minimum;
            minimum = factor;
            factor = temp;
        } else {
            temp = maximum;
            maximum = factor;
            factor = temp;
        }
    }

    /* Remainder = number - q*modulus, which is < modulus. */
    subtract(product, number, product, 2 * size);
    if (words(product, 2 * size) > size) {
        status = EDOM;
        goto out;
    }
    memcpy(result, product, size * sizeof (uint32_t));

out:
    free(maximum);
    free(minimum);
    free(factor);
    free(product);
    if (status)
        errno = status;
    return status;
}
/* Parse one command-line argument as a natural number for main().
 * On success, stores the number in *numberptr / *sizeptr and returns 0.
 * On failure, prints a diagnostic naming the argument, leaves *numberptr
 * NULL (nothing to free), and returns 1.
 */
static int parse_argument(const char *arg, uint32_t **numberptr, size_t *sizeptr)
{
    const char *tail;

    *numberptr = NULL;
    *sizeptr = 0;
    tail = string_to_natural(arg, numberptr, sizeptr);
    if (!tail) {
        fprintf(stderr, "%s: Not a positive decimal integer.\n", arg);
        return 1;
    }
    tail = nonspace(tail);
    if (tail) {
        fprintf(stderr, "%s: Garbage (%s) at end of number.\n", arg, tail);
        free(*numberptr);
        *numberptr = NULL;
        return 1;
    }
    return 0;
}

/* Compute (number*number) mod modulus for two decimal command-line arguments
 * and print the result in decimal. Exit status is 0 on success and 1 on any
 * error. Every allocation is released on every path.
 */
int main(int argc, char *argv[])
{
    uint32_t *number = NULL, *modulus = NULL, *square = NULL, *result = NULL;
    size_t number_size = 0, modulus_size = 0;
    char *string = NULL;
    int status = 1;

    if (argc != 3 || !strcmp(argv[1], "-h") || !strcmp(argv[1], "--help")) {
        fprintf(stderr, "\n");
        fprintf(stderr, "Usage: %s number modulus\n", argv[0]);
        fprintf(stderr, "\n");
        fprintf(stderr, "Given modulus > number > 0, this computes\n");
        fprintf(stderr, " (number*number) mod modulus\n");
        fprintf(stderr, "Both inputs and outputs are in decimal.\n");
        fprintf(stderr, "\n");
        return 0;
    }

    if (parse_argument(argv[1], &number, &number_size))
        goto out;
    if (parse_argument(argv[2], &modulus, &modulus_size))
        goto out;

    if (compare(number, number_size, modulus, modulus_size) >= 0) {
        fprintf(stderr, "%s: Modulus must be larger than number.\n", argv[2]);
        goto out;
    }

    /* number < modulus, so number^2 fits in 2*modulus_size words. */
    square = duplicate(NULL, 0, 2 * modulus_size);
    result = duplicate(NULL, 0, modulus_size);
    if (!square || !result) {
        fprintf(stderr, "Not enough memory.\n");
        goto out;
    }

    multiply(square, number, number_size, number, number_size);
    if (modulo(result, square, modulus, modulus_size)) {
        fprintf(stderr, "modulo() bugged out (%s).\n", strerror(errno));
        goto out;
    }

    string = natural_to_string(result, modulus_size);
    if (!string) {
        fprintf(stderr, "Not enough memory.\n");
        goto out;
    }

    /* A failed write (closed pipe, full disk) must not report success. */
    if (fputs(string, stdout) == EOF || fputc('\n', stdout) == EOF || fflush(stdout) == EOF) {
        fprintf(stderr, "Error writing output.\n");
        goto out;
    }
    status = 0;

out:
    free(string);
    free(result);
    free(square);
    free(modulus);
    free(number);
    return status;
}
In its current form, the program performs only one iteration: it squares the number once and reduces the result modulo the modulus. The number and the modulus are supplied as command-line parameters.