// Burrows Wheeler Transform Encoder/Decoder

#ifdef unix
#define __cdecl
#endif

#include <stdlib.h>
#include <memory.h>
#include <string.h>

//  One record per sort key.  The two values are kept in
//  the same record so that both are fetched together
//  (this improves processor cache hits).

typedef struct {
    unsigned prefix;    // not read in this file; presumably sort-key data used by bwtsort
    unsigned offset;    // read by bwt_encode as the key's position in the input buffer
} KeyPrefix;

//  link to suffix sort module
//
//  bwtsort is defined in another file.  As used by bwt_encode
//  in this file: it is handed the input buffer and its length,
//  the result is indexed as one KeyPrefix per input byte, and
//  the result is released with free().

extern KeyPrefix *bwtsort(unsigned char *, unsigned);

//  these functions link bit-level I/O
//
//  The coder only calls them; a stand-alone set of definitions
//  is compiled at the end of this file when CODERSTANDALONE is
//  defined.  The notes below describe that stand-alone set.

void arc_put1 (int bit);    // write one bit; any non-zero argument writes a 1
void arc_put8 (int byte);   // write eight bits, most significant first
int arc_get1 ();            // read one bit, returned as 0 or 1
int arc_get8 ();            // read eight bits, most significant first

//  Two extra alphabet symbols, numbered just after the 256
//  byte codes, carry the digits of a zero-run length:
//  HUFF_bit0 stands for the value 1, HUFF_bit1 for the value 2

#define HUFF_bit0 256
#define HUFF_bit1 257

//  the size of the Huffman alphabet
//  (256 byte codes plus the two run-length symbols)

#define HUFF_size 258

//  store these values together for bwt decoding
//
//  rle_decode fills one entry per row of the block;
//  bwt_decode then follows the entries to rebuild the text.

typedef struct {
    unsigned code:8;    // the byte decoded for this row
    unsigned cnt:24;    // how many earlier counted rows held the same byte
} Xform;

//  the Huffman table entry for each alphabet symbol

typedef struct {
    unsigned len;       // code length in bits; zero when the symbol is unused
    unsigned code;      // code value; huff_encode sends its low `len` bits
} HuffTable;

HuffTable HuffCode[HUFF_size];  // indexed by symbol; cleared and rebuilt by encode_tree

//  Work node used to construct the HuffCode table.
//
//  A base (leaf) node has right == NULL and carries its
//  symbol number in `left`, stored as an integer cast to
//  a pointer.  A constructed node points at its two children.

struct Node {
    struct Node *left;
    struct Node *right;
    unsigned freq;      // symbol count, or the sum of both children
};

unsigned Freq[HUFF_size];           // alphabet counts
unsigned ZeroCnt;                   // zero codes seen but not yet counted (see mtf_encode)
unsigned char MtfOrder[256];        // move-to-front

//  Enumerate coding tree depths.
//
//  codes   per-depth leaf counters supplied by the caller;
//          codes[d] is incremented for each leaf at depth d
//  node    root of the subtree to walk; a leaf has
//          right == NULL and its symbol number in `left`
//  depth   depth of `node` itself
//
//  Stores each leaf's depth in HuffCode[symbol].len and
//  returns the greatest leaf depth found in the subtree.

unsigned enumerate (unsigned *codes, struct Node *node, unsigned depth)
{
unsigned one, two;

    if( !node->right ) {    // leaf node?
        // the symbol number was stored as an integer cast to a
        // pointer; convert it back through size_t instead of int
        // so the cast does not narrow where pointers are wider
        // than int

        HuffCode[(size_t)(node->left)].len = depth;
        codes[depth]++;
        return depth;
    }

    one = enumerate (codes, node->left, depth + 1);
    two = enumerate (codes, node->right, depth + 1);

    // return the max depth of the two sub-trees

    return one > two ? one : two;
}

//  qsort comparison routine: orders Nodes by ascending freq.
//
//  The counts are compared rather than subtracted: the
//  difference of two unsigned values converted to int has
//  the wrong sign once they are more than INT_MAX apart.

int __cdecl comp_node (const void *left, const void *right)
{
unsigned a = ((const struct Node *)left)->freq;
unsigned b = ((const struct Node *)right)->freq;

    return (a > b) - (a < b);
}

// construct Huffman coding tree
//
// Reads the symbol counts in Freq[] and rebuilds HuffCode[]:
// every symbol with a non-zero count receives a code length
// and a canonical code value; all other entries stay zero.
//
// NOTE: codes[] and rank[] hold 32 entries and rle_encode sends
// each length in 5 bits, so the tree depth must stay below 31.
// Nothing here enforces that limit.

void encode_tree ()
{
struct Node tree[2 * HUFF_size], *base = tree, *left, *right;
unsigned codes[32], rank[32], weight;
int idx, max;
int size;

    // the node tree is built with all the base symbols
    // then constructed nodes are appended

    memset (HuffCode, 0, sizeof(HuffCode));
    memset (tree, 0, sizeof(tree));

    // sort base symbol nodes by their frequencies

    for( size = 0; size < HUFF_size; size++ ) {
        // a base node keeps its symbol number in `left`;
        // widen through size_t so the integer-to-pointer
        // cast is between types of the same width

        tree[size].left = (void *)(size_t)size;
        tree[size].freq = Freq[size];
        tree[size].right = NULL;    // indicates a base node
    }

    qsort (tree, HUFF_size, sizeof(struct Node), comp_node);

    // repeatedly combine & remove two lowest freq nodes
    // then construct a new node w/sum of these two freq
    // and insert from the end of the tree (base + size)

    // at the top of each pass `size` is the number of
    // nodes still waiting at base[0 .. size-1]

    while( size-- > 1 ) {
        left = base++;
        weight = left->freq;

        if( !weight )
            continue;    // skip over zero freq nodes

        right = base++;
        weight += right->freq;
        idx = size;

        // sort new node into place: shift heavier nodes up
        // one slot until the insertion point is found

        while( --idx ) {
            if( base[idx-1].freq <= weight )
                break;

            base[idx] = base[idx-1];
        }

        // construct the new internal node

        base[idx].freq = weight;
        base[idx].right = right;
        base[idx].left = left;
    }

    // base points at root of tree (size == 1)
    // construct the Huffman code down from here

    memset (codes, 0, sizeof(codes));
    memset (rank, 0, sizeof(rank));

    // enumerate the left & right subtrees,
    // returns the deepest path to leaves

    max = enumerate (rank, base, 0);

    // use canonical Huffman coding technique
    // (from Steve Pigeon)

    // codes[n] becomes the first code value of length n;
    // rank[n] is cleared for reuse as a running index below

    for( idx = 0; idx <= max; idx++ ) {
        codes[idx + 1] = (codes[idx] + rank[idx]) << 1;
        rank[idx] = 0;
    }

    // set the code for each non-zero freq alphabet symbol

    for( idx = 0; idx < HUFF_size; idx++ )
      if( HuffCode[idx].len )
        HuffCode[idx].code = codes[HuffCode[idx].len] + rank[HuffCode[idx].len]++;
}

//  Output the code bits for one alphabet symbol.
//
//  val   symbol number, used as an index into HuffCode[]
//
//  Sends HuffCode[val].len bits of HuffCode[val].code through
//  arc_put1, most significant bit first, and returns the code.

unsigned huff_encode (unsigned val)
{
// unsigned constant: a signed 1 << 31 would be undefined, and
// the 5-bit length field allows lengths up to 31

unsigned mask = 1u << HuffCode[val].len;
unsigned code = HuffCode[val].code;

    // mask starts one position above the code's top bit

    while( mask >>= 1 )
        arc_put1 (code & mask);

    return code;
}

//  Transmit one block: the code-length table first, then the
//  move-to-front codes with runs of zero replaced by a
//  run-length written in the two symbols HUFF_bit0 / HUFF_bit1.
//
//  out    move-to-front codes of the block
//  size   number of codes in `out`

void rle_encode (unsigned char *out, unsigned size)
{
unsigned sym, len, bit;
unsigned pos, run;

    // transmit canonical huff coding tree by
    // sending 5 bits for each symbol's length,
    // most significant bit first

    for( sym = 0; sym < HUFF_size; sym++ ) {
        len = HuffCode[sym].len;

        for( bit = 1 << 4; bit; bit >>= 1 )
            arc_put1 (len & bit);
    }

    // repeated zeroes are first counted, then the count
    // is transmitted using the digits 1 (HUFF_bit0) and
    // 2 (HUFF_bit1), least significant digit first:

    //   HUFF_bit0            = count of 1
    //   HUFF_bit1            = count of 2
    //   HUFF_bit0, HUFF_bit0 = count of 3
    //   HUFF_bit0, HUFF_bit1 = count of 4
    //   HUFF_bit1, HUFF_bit0 = count of 5
    //   HUFF_bit1, HUFF_bit1 = count of 6 ...

    run = 0;

    for( pos = 0; pos < size; pos++ ) {
        sym = out[pos];

        // a zero joins the current run, unless it is the
        // very last code: to make decoding simpler that
        // one is always transmitted as an ordinary symbol

        if( sym == 0 && pos != size - 1 ) {
            run++;
            continue;
        }

        // transmit the digits of any pending run

        while( run ) {
            run--;
            huff_encode (HUFF_bit0 + (run & 0x1));
            run >>= 1;
        }

        huff_encode (sym);
    }
}

//  Move-to-Front decoder
//
//  nxt   position in MtfOrder of the wanted byte
//
//  Returns the byte found at that position after moving
//  it to the front of MtfOrder.

unsigned mtf_decode (unsigned nxt)
{
unsigned char front = MtfOrder[nxt];
unsigned slot;

    // slide everything ahead of the byte down one place

    for( slot = nxt; slot > 0; slot-- )
        MtfOrder[slot] = MtfOrder[slot - 1];

    MtfOrder[0] = front;
    return front;
}

// expand BWT into the supplied buffer
//
// xform   receives one entry per row: the decoded byte and the
//         number of earlier counted rows holding the same byte
// size    number of entries to produce
// last    row whose byte is not counted (the end-of-buffer row)
//
// Reads the 5-bit code lengths and then the coded symbols
// through arc_get1, undoing the zero-run coding and the
// move-to-front coding.  Freq[] is incremented once for every
// counted row.
//
// A damaged stream (no code recognised within 31 bits, or a
// zero run that does not fit in the block) stops the decode;
// the entries not yet produced are then set to zero.

void rle_decode (Xform *xform, unsigned size, unsigned last)
{
unsigned xlate[HUFF_size], length[HUFF_size];
unsigned codes[32], rank[32], base[32], bits;
unsigned nxt, count, lvl, idx, out = 0, zero;
unsigned char prev;
unsigned len;

    // construct decode table

    memset (codes, 0, sizeof(codes));
    memset (rank, 0, sizeof(rank));
    memset (base, 0, sizeof(base));

    // retrieve code lengths, 5 bits each

    for( idx = 0; idx < HUFF_size; idx++ ) {
        len = 0;

        for( bits = 0; bits < 5; bits++ )
            len = len << 1 | arc_get1();

        length[idx] = len;
        rank[len]++;
    }

    // construct canonical Huffman code groups
    // one group range for each bit length

    // every length a 5-bit field can carry (1..31) gets its
    // first code value, its first xlate slot, and a cleared
    // rank; entry 31 has no successor to compute

    for( idx = 1; idx < 32; idx++ ) {
        if( idx < 31 ) {
            codes[idx + 1] = (codes[idx] + rank[idx]) << 1;
            base[idx + 1] = base[idx] + rank[idx];
        }

        rank[idx] = 0;
    }

    // fill in the translated Huffman codes
    // filling in ranks for each code group

    for( idx = 0; idx < HUFF_size; idx++ ) {
        lvl = length[idx];

        if( lvl )
            xlate[base[lvl] + rank[lvl]++] = idx;
    }

    zero = count = bits = lvl = nxt = 0;
    prev = 0;

    // fill supplied buffer by reading the input
    // one bit at a time and assembling codes

    while( out < size ) {
      // no code is longer than 31 bits

      if( ++lvl > 31 )
        break;

      bits <<= 1, bits |= arc_get1 ();

      if( !rank[lvl] )
        continue;    // no symbols with this code length

      if( bits >= codes[lvl] + rank[lvl] )
        continue;    // the code is above the range for this length

      // nxt = the recognized symbol
      // reset code accumulator

      nxt = xlate[base[lvl] + bits - codes[lvl]];
      bits = lvl = 0;

      // process RLE count bits as 1 or 2

      if( nxt > 255 ) {
        if( zero > 30 )
          break;     // shift would reach the word size

        count += ( nxt - 255 ) << zero++;

        // the run and the symbol that must follow it
        // both have to fit in the rows still empty

        if( count >= size - out )
          break;

        continue;
      }

      // expand previous RLE count

      while( count ) {
        if( out != last )
            xform[out].cnt = Freq[prev]++;

        xform[out++].code = prev;
        count--;
      }

      zero = 0;
      prev = mtf_decode (nxt);   // translate mtf of the symbol

      if( out != last )
          xform[out].cnt = Freq[prev]++;

      xform[out++].code = prev;  // store next symbol
    }

    // decode abandoned early: leave defined values behind

    if( out < size )
        memset (xform + out, 0, (size - out) * sizeof(Xform));
}

//  Move-to-Front encoder, which also accumulates the
//  frequency counts of the symbols rle_encode will send.
//
//  val     the byte to encode
//  flush   non-zero to count a zero code at once instead
//          of adding it to the pending zero run
//
//  Returns the position `val` held in MtfOrder before it
//  was moved to the front.

unsigned char mtf_encode (unsigned char val, int flush)
{
unsigned char *hit = (unsigned char *)memchr (MtfOrder, val, 256);
unsigned code = hit - MtfOrder;

    // slide the bytes ahead of `val` down one place
    // and put `val` at the front

    while( hit > MtfOrder ) {
        *hit = hit[-1];
        hit--;
    }

    *hit = val;

    // a zero code only lengthens the pending run

    if( code == 0 && !flush ) {
        ZeroCnt++;
        return code;
    }

    //  accumulate the frequency counts for the
    //  new code and the previous zero run

    Freq[code]++;

    while( ZeroCnt ) {
        ZeroCnt--;
        Freq[HUFF_bit0 + (ZeroCnt & 0x1)]++;
        ZeroCnt >>= 1;
    }

    return code;
}

//  Initialize the Move-to-Front symbols:
//  every byte value starts at its own position.

void mtf_init ()
{
int sym = 256;

    while( sym-- )
        MtfOrder[sym] = (unsigned char)sym;
}

// unpack next bwt segment from current stream into buffer
//
// outbuff   receives the decoded bytes
// buflen    number of bytes to produce
//
// Calls abort() when the work table cannot be allocated.

void bwt_decode (unsigned char *outbuff, unsigned buflen)
{
unsigned last, idx = 0;
Xform *xform;
unsigned ch;

    // one entry per data byte plus one for the end-of-buffer row.
    // calloc is used so that the cnt field rle_decode never writes
    // (the end-of-buffer row) still holds a defined value, which
    // keeps the walk below inside the table even when the block
    // is malformed

    xform = calloc ((size_t)buflen + 1, sizeof(Xform));

    if( !xform )
        abort ();

    mtf_init ();

    // retrieve last row number

    last = (unsigned)arc_get8 () << 16;
    last |= (unsigned)arc_get8 () << 8;
    last |= (unsigned)arc_get8 ();

// To determine a character's position in the output string given
// its position in the input string, we can use the knowledge about
// the fact that the output string is sorted.  Each character 'c' will
// show up in the output stream in position i, where i is the sum
// total of all characters in the input buffer that precede c in the
// alphabet (kept in the count array), plus the count of all
// occurrences of 'c' previously in the block (kept in xform.cnt)

// The first part of this code calculates the running totals for all
// the characters in the alphabet.  That satisfies the first part of the
// equation needed to determine where each 'c' will go in the output
// stream. Remember that the character pointed to by 'last' is a special
// end-of-buffer character that needs to be larger than any other char
// so we just skip over it while tallying counts

    memset (Freq, 0, sizeof(Freq));
    rle_decode (xform, buflen + 1, last);

    for( idx = 1 ; idx < 256 ; idx++ )
        Freq[idx] += Freq[idx-1];

// Once the transformation vector is in place, writing the
// output is just a matter of computing the indices.  Note
// that we ignore the EOB from the end of data first, and
// process the array backwards from there

    last = idx = buflen;

    while( idx-- ) {
        ch = outbuff[idx] = xform[last].code;
        last = xform[last].cnt;

        // add the number of bytes that sort below ch

        if( ch-- )
            last += Freq[ch];
    }

    free (xform);
}

// pack next bwt segment into current stream
//
// buff   the bytes to encode
// max    number of bytes in `buff`
//
// Nothing is transmitted when max is zero.  Calls abort() when
// the sort table or the work buffer is unavailable.

void bwt_encode (unsigned char *buff, unsigned max)
{
unsigned idx, last = 0, off, size = 0;
unsigned char *out;
KeyPrefix *keys;

    // an empty segment has no keys and no last byte to read

    if( !max )
        return;

    // zero freq counts

    memset (Freq, 0, sizeof(Freq));
    ZeroCnt = 0;

    keys = bwtsort (buff, max);
    out = malloc ((size_t)max + 1);

    if( !keys || !out )
        abort ();

//    Finally, write out column L.  Column L consists of all
//    the prefix characters to the sorted strings, in order.
//    It's easy to get the prefix character, but offset 0
//    is handled with care, since its prefix character
//    is the imaginary end-of-buffer character.  Save the
//    position in L of the end-of-buffer character and
//    write it out at the end to the output stream.

    mtf_init ();

    for( idx = 0; idx < max; idx++ ) {
        off = keys[idx].offset;

        if( off )
            out[size++] = mtf_encode (buff[off - 1], 0);
        else {
            // end-of-buffer row: remember where it is and
            // encode the current front byte in its place

            last = idx;
            out[size++] = mtf_encode (MtfOrder[0], 0);
        }
    }

    // the final code is counted at once (flush)

    out[size++] = mtf_encode (buff[max - 1], 1);

    // transmit where the EOB is located

    arc_put8 ((unsigned char)(last >> 16));
    arc_put8 ((unsigned char)(last >> 8));
    arc_put8 ((unsigned char)(last));

    // construct huff coding tree and transmit code-lengths

    encode_tree ();

    // encode and transmit output

    rle_encode (out, size);

    free (keys);
    free (out);
}

#ifdef CODERSTANDALONE

#include <stdio.h>

unsigned char ArcBit = 0;   // bits gathered so far when writing, bits left to hand out when reading
unsigned char ArcChar = 0;  // the byte being assembled or consumed
FILE *In;                   // input file, opened by main
FILE *Out;                  // output file, opened by main

//  Stand-alone driver.
//
//  usage: <program> <mode> <input file> <output file>
//
//  A mode starting with 'd' decompresses; any other mode
//  compresses.  Each block is preceded by its length in three
//  bytes, high byte first, and holds at most 900000 bytes.
//
//  Returns 0 on success and 1 on a usage, file, memory or
//  truncated-input error.

int main (int argc, char **argv)
{
int mode, max, size, nxt, ch;
int status = 0;
unsigned char *buff;

    // the mode and both file names are required

    if( argc < 4 )
        return 1;

    mode = argv[1][0];

    if( !(In = fopen (argv[2], "rb")) )
        return 1;

    if( !(Out = fopen (argv[3], "wb")) ) {
        fclose (In);
        return 1;
    }

    if( mode == 'd' ) {

        //  decompression: end of input before a block
        //  header is the normal way out of this loop

        while( (size = getc (In)) >= 0 ) {
            for( nxt = 0; nxt < 2; nxt++ ) {
                if( (ch = getc (In)) < 0 ) {
                    status = 1;     // input ends inside a header
                    break;
                }

                size = size << 8 | ch;
            }

            if( status )
                break;

            ArcBit = 0;     // each block starts on a byte boundary

            if( !size )
                continue;

            if( !(buff = malloc (size)) ) {
                status = 1;
                break;
            }

            bwt_decode (buff, size);

            for( nxt = 0; nxt < size; nxt++ )
                putc (buff[nxt], Out);

            free (buff);
        }
    } else {

        // compression

        fseek (In, 0, SEEK_END);
        size = ftell (In);
        fseek (In, 0, SEEK_SET);

        if( size < 0 )
            status = 1;     // the input length is unavailable

        while( !status ) {
            max = size > 900000 ? 900000 : size;
            buff = NULL;

            if( max && !(buff = malloc (max)) ) {
                status = 1;
                break;
            }

            putc ((unsigned char)(max >> 16), Out);
            putc ((unsigned char)(max >> 8), Out);
            putc ((unsigned char)(max), Out);

            for( nxt = 0; nxt < max; nxt++ )
                buff[nxt] = getc(In);

            if( max )
                bwt_encode (buff, max);

            free (buff);

            while( ArcBit )  // flush last few bits
                arc_put1 (0);

            // an empty input still produces one empty block

            if( !(size -= max) )
                break;
        }
    }

    fclose (In);

    // closing flushes the output; report a failed write

    if( fclose (Out) )
        status = 1;

    return status;
}

//  Append one bit to the output.  Bits are gathered in
//  ArcChar, first bit highest, and the byte is written
//  to Out once eight have arrived.

void arc_put1 (int bit)
{
    ArcChar = (unsigned char)(ArcChar << 1);

    if( bit )
        ArcChar |= 1;

    ArcBit++;

    if( ArcBit >= 8 ) {
        putc (ArcChar, Out);
        ArcChar = 0;
        ArcBit = 0;
    }
}

//  Write the low eight bits of `ch`, highest bit first.

void arc_put8 (int ch)
{
int mask;

    for( mask = 1 << 7; mask; mask >>= 1 )
        arc_put1 (ch & mask);
}

//  Read one bit.  A fresh byte is taken from In whenever
//  the previous one is used up; its bits are handed out
//  highest first.  Returns 0 or 1.

int arc_get1 ()
{
    if( ArcBit == 0 ) {
        ArcChar = getc (In);
        ArcBit = 8;
    }

    ArcBit--;
    return (ArcChar >> ArcBit) & 1;
}

//  Read eight bits, first bit highest, and
//  return them as a value from 0 to 255.

int arc_get8 ()
{
int value = 0;
int left = 8;

    while( left-- )
        value = value << 1 | arc_get1();

    return value;
}
#endif