Source file src/image/jpeg/huffman.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jpeg
     6  
     7  import (
     8  	"io"
     9  )
    10  
    11  // maxCodeLength is the maximum (inclusive) number of bits in a Huffman code.
    12  const maxCodeLength = 16
    13  
    14  // maxNCodes is the maximum (inclusive) number of codes in a Huffman tree.
    15  const maxNCodes = 256
    16  
    17  // lutSize is the log-2 size of the Huffman decoder's look-up table.
    18  const lutSize = 8
    19  
    20  // huffman is a Huffman decoder, specified in section C.
    21  type huffman struct {
    22  	// length is the number of codes in the tree.
    23  	nCodes int32
    24  	// lut is the look-up table for the next lutSize bits in the bit-stream.
    25  	// The high 8 bits of the uint16 are the encoded value. The low 8 bits
    26  	// are 1 plus the code length, or 0 if the value is too large to fit in
    27  	// lutSize bits.
    28  	lut [1 << lutSize]uint16
    29  	// vals are the decoded values, sorted by their encoding.
    30  	vals [maxNCodes]uint8
    31  	// minCodes[i] is the minimum code of length i, or -1 if there are no
    32  	// codes of that length.
    33  	minCodes [maxCodeLength]int32
    34  	// maxCodes[i] is the maximum code of length i, or -1 if there are no
    35  	// codes of that length.
    36  	maxCodes [maxCodeLength]int32
    37  	// valsIndices[i] is the index into vals of minCodes[i].
    38  	valsIndices [maxCodeLength]int32
    39  }
    40  
    41  // errShortHuffmanData means that an unexpected EOF occurred while decoding
    42  // Huffman data.
    43  var errShortHuffmanData = FormatError("short Huffman data")
    44  
    45  // ensureNBits reads bytes from the byte buffer to ensure that d.bits.n is at
    46  // least n. For best performance (avoiding function calls inside hot loops),
    47  // the caller is the one responsible for first checking that d.bits.n < n.
    48  func (d *decoder) ensureNBits(n int32) error {
    49  	for {
    50  		c, err := d.readByteStuffedByte()
    51  		if err != nil {
    52  			if err == io.ErrUnexpectedEOF {
    53  				return errShortHuffmanData
    54  			}
    55  			return err
    56  		}
    57  		d.bits.a = d.bits.a<<8 | uint32(c)
    58  		d.bits.n += 8
    59  		if d.bits.m == 0 {
    60  			d.bits.m = 1 << 7
    61  		} else {
    62  			d.bits.m <<= 8
    63  		}
    64  		if d.bits.n >= n {
    65  			break
    66  		}
    67  	}
    68  	return nil
    69  }
    70  
    71  // receiveExtend is the composition of RECEIVE and EXTEND, specified in section
    72  // F.2.2.1.
    73  //
    74  // It returns the signed integer that's encoded in t bits, where t < 16. The
    75  // possible return values are:
    76  //
    77  //   - t ==  0:   0
    78  //   - t ==  1:   -1, +1
    79  //   - t ==  2:   -3, -2, +2, +3
    80  //   - t ==  3:   -7, -6, -5, -4, +4, +5, +6, +7
    81  //   - ...
    82  //   - t == 15:   -32767, -32766, ..., -16384, +16384, ..., +32766, +32767
    83  func (d *decoder) receiveExtend(t uint8) (int32, error) {
    84  	if d.bits.n < int32(t) {
    85  		if err := d.ensureNBits(int32(t)); err != nil {
    86  			return 0, err
    87  		}
    88  	}
    89  	d.bits.n -= int32(t)
    90  	d.bits.m >>= t
    91  	s := int32(1) << t
    92  	x := int32(d.bits.a>>uint8(d.bits.n)) & (s - 1)
    93  
    94  	// This adjustment, assuming two's complement, is a branchless equivalent of:
    95  	//
    96  	// if x < s>>1 {
    97  	//   x += ((-1) << t) + 1
    98  	// }
    99  	//
   100  	// sign is either -1 or 0, depending on whether x is in the low or high
   101  	// half of the range 0 .. 1<<t.
   102  	sign := (x >> (t - 1)) - 1
   103  	x += sign & (((-1) << t) + 1)
   104  
   105  	return x, nil
   106  }
   107  
   108  // processDHT processes a Define Huffman Table marker, and initializes a huffman
   109  // struct from its contents. Specified in section B.2.4.2.
   110  func (d *decoder) processDHT(n int) error {
   111  	for n > 0 {
   112  		if n < 17 {
   113  			return FormatError("DHT has wrong length")
   114  		}
   115  		if err := d.readFull(d.tmp[:17]); err != nil {
   116  			return err
   117  		}
   118  		tc := d.tmp[0] >> 4
   119  		if tc > maxTc {
   120  			return FormatError("bad Tc value")
   121  		}
   122  		th := d.tmp[0] & 0x0f
   123  		// The baseline th <= 1 restriction is specified in table B.5.
   124  		if th > maxTh || (d.baseline && th > 1) {
   125  			return FormatError("bad Th value")
   126  		}
   127  		h := &d.huff[tc][th]
   128  
   129  		// Read nCodes and h.vals (and derive h.nCodes).
   130  		// nCodes[i] is the number of codes with code length i.
   131  		// h.nCodes is the total number of codes.
   132  		h.nCodes = 0
   133  		var nCodes [maxCodeLength]int32
   134  		for i := range nCodes {
   135  			nCodes[i] = int32(d.tmp[i+1])
   136  			h.nCodes += nCodes[i]
   137  		}
   138  		if h.nCodes == 0 {
   139  			return FormatError("Huffman table has zero length")
   140  		}
   141  		if h.nCodes > maxNCodes {
   142  			return FormatError("Huffman table has excessive length")
   143  		}
   144  		n -= int(h.nCodes) + 17
   145  		if n < 0 {
   146  			return FormatError("DHT has wrong length")
   147  		}
   148  		if err := d.readFull(h.vals[:h.nCodes]); err != nil {
   149  			return err
   150  		}
   151  
   152  		// Derive the look-up table.
   153  		clear(h.lut[:])
   154  		var x, code uint32
   155  		for i := uint32(0); i < lutSize; i++ {
   156  			code <<= 1
   157  			for j := int32(0); j < nCodes[i]; j++ {
   158  				// The codeLength is 1+i, so shift code by 8-(1+i) to
   159  				// calculate the high bits for every 8-bit sequence
   160  				// whose codeLength's high bits matches code.
   161  				// The high 8 bits of lutValue are the encoded value.
   162  				// The low 8 bits are 1 plus the codeLength.
   163  				base := uint8(code << (7 - i))
   164  				lutValue := uint16(h.vals[x])<<8 | uint16(2+i)
   165  				for k := uint8(0); k < 1<<(7-i); k++ {
   166  					h.lut[base|k] = lutValue
   167  				}
   168  				code++
   169  				x++
   170  			}
   171  		}
   172  
   173  		// Derive minCodes, maxCodes, and valsIndices.
   174  		var c, index int32
   175  		for i, n := range nCodes {
   176  			if n == 0 {
   177  				h.minCodes[i] = -1
   178  				h.maxCodes[i] = -1
   179  				h.valsIndices[i] = -1
   180  			} else {
   181  				h.minCodes[i] = c
   182  				h.maxCodes[i] = c + n - 1
   183  				h.valsIndices[i] = index
   184  				c += n
   185  				index += n
   186  			}
   187  			c <<= 1
   188  		}
   189  	}
   190  	return nil
   191  }
   192  
   193  // decodeHuffman returns the next Huffman-coded value from the bit-stream,
   194  // decoded according to h.
   195  func (d *decoder) decodeHuffman(h *huffman) (uint8, error) {
   196  	if h.nCodes == 0 {
   197  		return 0, FormatError("uninitialized Huffman table")
   198  	}
   199  
   200  	if d.bits.n < 8 {
   201  		if err := d.ensureNBits(8); err != nil {
   202  			if err != errMissingFF00 && err != errShortHuffmanData {
   203  				return 0, err
   204  			}
   205  			// There are no more bytes of data in this segment, but we may still
   206  			// be able to read the next symbol out of the previously read bits.
   207  			// First, undo the readByte that the ensureNBits call made.
   208  			if d.bytes.nUnreadable != 0 {
   209  				d.unreadByteStuffedByte()
   210  			}
   211  			goto slowPath
   212  		}
   213  	}
   214  	if v := h.lut[(d.bits.a>>uint32(d.bits.n-lutSize))&0xff]; v != 0 {
   215  		n := (v & 0xff) - 1
   216  		d.bits.n -= int32(n)
   217  		d.bits.m >>= n
   218  		return uint8(v >> 8), nil
   219  	}
   220  
   221  slowPath:
   222  	for i, code := 0, int32(0); i < maxCodeLength; i++ {
   223  		if d.bits.n == 0 {
   224  			if err := d.ensureNBits(1); err != nil {
   225  				return 0, err
   226  			}
   227  		}
   228  		if d.bits.a&d.bits.m != 0 {
   229  			code |= 1
   230  		}
   231  		d.bits.n--
   232  		d.bits.m >>= 1
   233  		if code <= h.maxCodes[i] {
   234  			return h.vals[h.valsIndices[i]+code-h.minCodes[i]], nil
   235  		}
   236  		code <<= 1
   237  	}
   238  	return 0, FormatError("bad Huffman code")
   239  }
   240  
   241  func (d *decoder) decodeBit() (bool, error) {
   242  	if d.bits.n == 0 {
   243  		if err := d.ensureNBits(1); err != nil {
   244  			return false, err
   245  		}
   246  	}
   247  	ret := d.bits.a&d.bits.m != 0
   248  	d.bits.n--
   249  	d.bits.m >>= 1
   250  	return ret, nil
   251  }
   252  
   253  func (d *decoder) decodeBits(n int32) (uint32, error) {
   254  	if d.bits.n < n {
   255  		if err := d.ensureNBits(n); err != nil {
   256  			return 0, err
   257  		}
   258  	}
   259  	ret := d.bits.a >> uint32(d.bits.n-n)
   260  	ret &= (1 << uint32(n)) - 1
   261  	d.bits.n -= n
   262  	d.bits.m >>= uint32(n)
   263  	return ret, nil
   264  }
   265  

View as plain text