import heapq, sys
# True when the running interpreter is Python 3 or newer.
python3 = sys.version_info >= (3,)
# ---- Huffman coding core classes ----
class HuffmanEncoder(object):
    """Encodes symbols and writes them to a Huffman-coded bit stream.

    Attributes:
        output: the underlying bit output stream; every code bit is passed,
            one at a time, to its write() method.
        codetree: the CodeTree used by the next write() call. It starts out
            as None and must be assigned a CodeTree before write() is called.
            It may be replaced between symbols, as long as the decoder uses
            the same tree at the same point in the code stream.
    """

    def __init__(self, bitout):
        """Create an encoder on top of the bit output stream `bitout`.

        The code tree is left unset (None); assign `codetree` before writing.
        """
        self.output = bitout
        self.codetree = None

    def write(self, symbol):
        """Encode `symbol` with the current code tree and emit its bits.

        Raises:
            ValueError: if `codetree` is not a CodeTree instance. Errors from
                CodeTree.get_code (e.g. a symbol with no code) propagate.
        """
        tree = self.codetree
        if not isinstance(tree, CodeTree):
            raise ValueError("Invalid current code tree")
        code = tree.get_code(symbol)
        for bit in code:
            self.output.write(bit)
class HuffmanDecoder(object):
    """Reads from a Huffman-coded bit stream and decodes symbols.

    Attributes:
        input: the underlying bit input stream; bits are pulled from it with
            read_no_eof().
        codetree: the CodeTree used by the next read() call. It starts out
            as None and must be assigned a CodeTree before read() is called.
            It may be replaced between symbols, as long as the encoder used
            the same tree at the same point in the code stream.
    """

    def __init__(self, bitin):
        """Create a decoder on top of the bit input stream `bitin`.

        The code tree is left unset (None); assign `codetree` before reading.
        """
        self.input = bitin
        self.codetree = None

    def read(self):
        """Decode and return the next symbol from the input stream.

        Starting at the tree's root, each bit selects a child (0 = left,
        1 = right) until a Leaf is reached; that leaf's symbol is returned.
        What happens at end of stream is up to read_no_eof() — presumably it
        raises; confirm against the bit input stream implementation.

        Raises:
            ValueError: if `codetree` is not a CodeTree instance.
            AssertionError: if read_no_eof() yields something other than 0/1,
                or the walk reaches a node that is neither Leaf nor
                InternalNode.
        """
        if not isinstance(self.codetree, CodeTree):
            raise ValueError("Invalid current code tree")
        node = self.codetree.root
        while True:
            bit = self.input.read_no_eof()
            if bit == 0:
                node = node.leftchild
            elif bit == 1:
                node = node.rightchild
            else:
                raise AssertionError("Invalid value from read_no_eof()")
            if isinstance(node, Leaf):
                return node.symbol
            if not isinstance(node, InternalNode):
                raise AssertionError("Illegal node type")
            # Otherwise `node` is internal: keep descending.
class FrequencyTable(object):
    """A mutable table of symbol frequencies.

    Symbol values are numbered from 0 to get_symbol_limit() - 1. A frequency
    table is mainly used like this:
      0. Collect the frequencies of symbols in the stream to be compressed.
      1. Build a code tree that is statically optimal for those frequencies
         (see build_code_tree()).
    """

    def __init__(self, freqs):
        """Construct a frequency table from the given sequence of frequencies.

        The sequence is copied, so later changes to `freqs` do not affect the
        table.

        Raises:
            ValueError: if fewer than 2 frequencies are given, or any
                frequency is negative.
        """
        self.frequencies = list(freqs)  # Defensive copy
        if len(self.frequencies) < 2:
            raise ValueError("At least 2 symbols needed")
        if any(x < 0 for x in self.frequencies):
            raise ValueError("Negative frequency")

    def get_symbol_limit(self):
        """Return the number of symbols in this table (always at least 2)."""
        return len(self.frequencies)

    def get(self, symbol):
        """Return the (non-negative) frequency of `symbol`.

        Raises:
            ValueError: if `symbol` is out of range.
        """
        self._check_symbol(symbol)
        return self.frequencies[symbol]

    def set(self, symbol, freq):
        """Set the frequency of `symbol` to `freq`.

        Raises:
            ValueError: if `symbol` is out of range or `freq` is negative.
        """
        self._check_symbol(symbol)
        if freq < 0:
            raise ValueError("Negative frequency")
        self.frequencies[symbol] = freq

    def increment(self, symbol):
        """Add 1 to the frequency of `symbol`.

        Raises:
            ValueError: if `symbol` is out of range.
        """
        self._check_symbol(symbol)
        self.frequencies[symbol] += 1

    def _check_symbol(self, symbol):
        """Raise ValueError unless 0 <= symbol < len(frequencies)."""
        if not 0 <= symbol < len(self.frequencies):
            raise ValueError("Symbol out of range")

    def __str__(self):
        """Return one "symbol<TAB>frequency" line per symbol.

        Useful for debugging only; the format is subject to change.
        """
        # join() instead of repeated += avoids quadratic string building.
        return "".join("{}\t{}\n".format(i, freq)
                       for (i, freq) in enumerate(self.frequencies))

    def build_code_tree(self):
        """Return a CodeTree that is optimal for the current frequencies.

        The tree always contains at least 2 leaves, padding with zero-frequency
        symbols if necessary, so that it is never degenerate. When two subtrees
        have equal frequency, the tie is broken by the lowest symbol each one
        contains, so the result is deterministic. Note that optimal trees are
        not unique in general.
        """
        # Heap entries are (frequency, lowest symbol in subtree, node). Tuples
        # compare element by element, so entries are ordered by frequency
        # ascending, then by lowest symbol. Every symbol is in exactly one
        # subtree, so lowest symbols are unique among queued entries and the
        # comparison never falls through to the (unorderable) node objects.
        pqueue = [(freq, i, Leaf(i))
                  for (i, freq) in enumerate(self.frequencies) if freq > 0]

        # Pad with the lowest-numbered zero-frequency symbols until there are
        # at least 2 entries.
        for (i, freq) in enumerate(self.frequencies):
            if len(pqueue) >= 2:
                break
            if freq == 0:
                pqueue.append((freq, i, Leaf(i)))
        assert len(pqueue) >= 2
        heapq.heapify(pqueue)

        # Repeatedly merge the two lowest-frequency subtrees.
        while len(pqueue) > 1:
            xfreq, xlow, xnode = heapq.heappop(pqueue)
            yfreq, ylow, ynode = heapq.heappop(pqueue)
            heapq.heappush(pqueue, (xfreq + yfreq, min(xlow, ylow),
                                    InternalNode(xnode, ynode)))

        # The single remaining entry holds the root of the finished tree.
        return CodeTree(pqueue[0][2], len(self.frequencies))
# A binary tree that represents a mapping between symbols and binary strings.
# There are two main uses of a code tree:
# - Read the root field and walk through the tree to extract the desired information.
# - Call get_code() to get the binary code for a particular encodable symbol.
# The path to a leaf node determines the leaf's symbol's code. Starting from the
# root, going to the left child represents a 0, and going to the right child
# represents a 1. Constraints:
# - The root must be an internal node, and the tree is finite.
# - No symbol value is found in more than one leaf.
# - Not every possible symbol value needs to be in the tree.
# Illustrated example:
#   Huffman codes:
#     0: Symbol A
#     10: Symbol B
#     110: Symbol C
#     111: Symbol D
#   Code tree:
#       .
#      / \
#     A   .
#        / \
#       B   .
#          / \
#         C   D
class CodeTree(object):

    def __init__(self, root, symbollimit):
        """Construct a code tree from a tree of nodes and a symbol limit.

        Every leaf symbol must satisfy 0 <= symbol < symbollimit.

        Attributes set:
            root: the given root node.
            codes: list of length `symbollimit`; codes[s] is the code of
                symbol s as a tuple of 0/1 ints (e.g. code 10011 is
                (1, 0, 0, 1, 1)), or None if s has no leaf in the tree.

        Raises:
            ValueError: if symbollimit < 2, the root is not an InternalNode,
                a leaf symbol is negative or >= symbollimit, or a symbol
                appears in more than one leaf.
            AssertionError: if a node is neither an InternalNode nor a Leaf.
        """
        # Recursive pre-order walk that records the path to every leaf.
        def build_code_list(node, prefix):
            if isinstance(node, InternalNode):
                build_code_list(node.leftchild , prefix + (0,))
                build_code_list(node.rightchild, prefix + (1,))
            elif isinstance(node, Leaf):
                # Reject negative symbols explicitly: Python's negative
                # indexing would otherwise silently store the code in
                # another symbol's slot.
                if node.symbol < 0:
                    raise ValueError("Illegal symbol")
                if node.symbol >= symbollimit:
                    raise ValueError("Symbol exceeds symbol limit")
                if self.codes[node.symbol] is not None:
                    raise ValueError("Symbol has more than one code")
                self.codes[node.symbol] = prefix
            else:
                raise AssertionError("Illegal node type")

        if symbollimit < 2:
            raise ValueError("At least 2 symbols needed")
        # Enforce the documented constraint: a leaf root would give its symbol
        # an empty code, which the encoder/decoder cannot handle.
        if not isinstance(root, InternalNode):
            raise ValueError("Root must be an internal node")
        self.root = root
        self.codes = [None] * symbollimit
        build_code_list(root, ())  # Fill 'codes' with appropriate data

    def get_code(self, symbol):
        """Return the Huffman code for `symbol` as a tuple of 0s and 1s.

        Raises:
            ValueError: if `symbol` is outside [0, symbol limit), or the
                symbol has no leaf (hence no code) in this tree.
        """
        # Check both bounds so out-of-range symbols raise ValueError rather
        # than IndexError (too large) or wrap around (negative).
        if not 0 <= symbol < len(self.codes):
            raise ValueError("Illegal symbol")
        code = self.codes[symbol]
        if code is None:
            raise ValueError("No code for given symbol")
        return code
#