1
2
3
4
5
6 package base64
7
8 import (
9 "internal/byteorder"
10 "io"
11 "slices"
12 "strconv"
13 )
14
15
18
19
20
21
22
23
// An Encoding is a radix 64 encoding/decoding scheme defined by a
// 64-symbol alphabet, an optional padding character and a strictness flag.
// Encodings are created with NewEncoding and customised through WithPadding
// and Strict, both of which return modified copies.
type Encoding struct {
	// encode maps a 6-bit value (0..63) to the alphabet symbol that
	// represents it.
	encode [64]byte

	// decodeMap maps an input byte back to its 6-bit value. Bytes that are
	// not part of the alphabet hold 0xff.
	decodeMap [256]uint8

	// padChar is the padding character, or NoPadding for unpadded output.
	padChar rune

	// strict makes decoding reject input whose unused trailing bits are
	// not zero.
	strict bool
}
30
// Padding values accepted by Encoding.WithPadding.
const (
	// StdPadding is the standard padding character, '='.
	StdPadding = rune('=')

	// NoPadding selects an encoding that emits and expects no padding.
	NoPadding = rune(-1)
)
35
const (
	// invalid16 and invalid64 are runs of 16 and 64 bytes of 0xff, used
	// only to assemble decodeMapInitialize.
	invalid16 = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
	invalid64 = invalid16 + invalid16 + invalid16 + invalid16

	// decodeMapInitialize is 256 bytes of 0xff. It is copied into a fresh
	// decode map so that every byte starts out marked as "not in the
	// alphabet".
	decodeMapInitialize = invalid64 + invalid64 + invalid64 + invalid64

	// invalidIndex is the decode map value of a byte that is not part of
	// the alphabet.
	invalidIndex = '\xff'
)
56
57
58
59
60
61
62
63
64 func NewEncoding(encoder string) *Encoding {
65 if len(encoder) != 64 {
66 panic("encoding alphabet is not 64-bytes long")
67 }
68
69 e := new(Encoding)
70 e.padChar = StdPadding
71 copy(e.encode[:], encoder)
72 copy(e.decodeMap[:], decodeMapInitialize)
73
74 for i := 0; i < len(encoder); i++ {
75
76
77
78 switch {
79 case encoder[i] == '\n' || encoder[i] == '\r':
80 panic("encoding alphabet contains newline character")
81 case e.decodeMap[encoder[i]] != invalidIndex:
82 panic("encoding alphabet includes duplicate symbols")
83 }
84 e.decodeMap[encoder[i]] = uint8(i)
85 }
86 return e
87 }
88
89
90
91
92
93
94
95
96 func (enc Encoding) WithPadding(padding rune) *Encoding {
97 switch {
98 case padding < NoPadding || padding == '\r' || padding == '\n' || padding > 0xff:
99 panic("invalid padding")
100 case padding != NoPadding && enc.decodeMap[byte(padding)] != invalidIndex:
101 panic("padding contained in alphabet")
102 }
103 enc.padChar = padding
104 return &enc
105 }
106
107
108
109
110
111
112
113 func (enc Encoding) Strict() *Encoding {
114 enc.strict = true
115 return &enc
116 }
117
118
119 var StdEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/")
120
121
122
123 var URLEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
124
125
126
127
128 var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
129
130
131
132
133 var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
134
135
138
139
140
141
142
143
144
145 func (enc *Encoding) Encode(dst, src []byte) {
146 if len(src) == 0 {
147 return
148 }
149
150
151
152 _ = enc.encode
153
154 for len(src) >= 3 {
155
156 val := uint(src[0])<<16 | uint(src[1])<<8 | uint(src[2])
157
158 _ = dst[3]
159 dst[0] = enc.encode[val>>18&0x3F]
160 dst[1] = enc.encode[val>>12&0x3F]
161 dst[2] = enc.encode[val>>6&0x3F]
162 dst[3] = enc.encode[val&0x3F]
163
164 src = src[3:]
165 dst = dst[4:]
166 }
167
168
169 switch len(src) {
170 case 0:
171 return
172 case 1:
173 val := uint(src[0]) << 16
174 dst[0] = enc.encode[val>>18&0x3F]
175 dst[1] = enc.encode[val>>12&0x3F]
176 if enc.padChar != NoPadding {
177 dst[2] = byte(enc.padChar)
178 dst[3] = byte(enc.padChar)
179 }
180 case 2:
181 val := uint(src[0])<<16 | uint(src[1])<<8
182 dst[0] = enc.encode[val>>18&0x3F]
183 dst[1] = enc.encode[val>>12&0x3F]
184 dst[2] = enc.encode[val>>6&0x3F]
185 if enc.padChar != NoPadding {
186 dst[3] = byte(enc.padChar)
187 }
188 }
189 }
190
191
192
193 func (enc *Encoding) AppendEncode(dst, src []byte) []byte {
194 n := enc.EncodedLen(len(src))
195 dst = slices.Grow(dst, n)
196 enc.Encode(dst[len(dst):][:n], src)
197 return dst[:len(dst)+n]
198 }
199
200
201 func (enc *Encoding) EncodeToString(src []byte) string {
202 buf := make([]byte, enc.EncodedLen(len(src)))
203 enc.Encode(buf, src)
204 return string(buf)
205 }
206
207 type encoder struct {
208 err error
209 enc *Encoding
210 w io.Writer
211 buf [3]byte
212 nbuf int
213 out [1024]byte
214 }
215
216 func (e *encoder) Write(p []byte) (n int, err error) {
217 if e.err != nil {
218 return 0, e.err
219 }
220
221
222 if e.nbuf > 0 {
223 var i int
224 for i = 0; i < len(p) && e.nbuf < 3; i++ {
225 e.buf[e.nbuf] = p[i]
226 e.nbuf++
227 }
228 n += i
229 p = p[i:]
230 if e.nbuf < 3 {
231 return
232 }
233 e.enc.Encode(e.out[:], e.buf[:])
234 if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
235 return n, e.err
236 }
237 e.nbuf = 0
238 }
239
240
241 for len(p) >= 3 {
242 nn := len(e.out) / 4 * 3
243 if nn > len(p) {
244 nn = len(p)
245 nn -= nn % 3
246 }
247 e.enc.Encode(e.out[:], p[:nn])
248 if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
249 return n, e.err
250 }
251 n += nn
252 p = p[nn:]
253 }
254
255
256 copy(e.buf[:], p)
257 e.nbuf = len(p)
258 n += len(p)
259 return
260 }
261
262
263
264 func (e *encoder) Close() error {
265
266 if e.err == nil && e.nbuf > 0 {
267 e.enc.Encode(e.out[:], e.buf[:e.nbuf])
268 _, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
269 e.nbuf = 0
270 }
271 return e.err
272 }
273
274
275
276
277
278
279 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
280 return &encoder{enc: enc, w: w}
281 }
282
283
284
285 func (enc *Encoding) EncodedLen(n int) int {
286 if enc.padChar == NoPadding {
287 return n/3*4 + (n%3*8+5)/6
288 }
289 return (n + 2) / 3 * 4
290 }
291
292
295
// CorruptInputError is the error reported for malformed base64 input. Its
// value is the byte offset in the input at which the problem was detected.
type CorruptInputError int64

// Error implements the error interface, reporting the offending offset in
// decimal.
func (e CorruptInputError) Error() string {
	const prefix = "illegal base64 data at input byte "
	return prefix + strconv.FormatInt(int64(e), 10)
}
301
302
303
304
305
306
307 func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
308
309 var dbuf [4]byte
310 dlen := 4
311
312
313 _ = enc.decodeMap
314
315 for j := 0; j < len(dbuf); j++ {
316 if len(src) == si {
317 switch {
318 case j == 0:
319 return si, 0, nil
320 case j == 1, enc.padChar != NoPadding:
321 return si, 0, CorruptInputError(si - j)
322 }
323 dlen = j
324 break
325 }
326 in := src[si]
327 si++
328
329 out := enc.decodeMap[in]
330 if out != 0xff {
331 dbuf[j] = out
332 continue
333 }
334
335 if in == '\n' || in == '\r' {
336 j--
337 continue
338 }
339
340 if rune(in) != enc.padChar {
341 return si, 0, CorruptInputError(si - 1)
342 }
343
344
345 switch j {
346 case 0, 1:
347
348 return si, 0, CorruptInputError(si - 1)
349 case 2:
350
351
352 for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
353 si++
354 }
355 if si == len(src) {
356
357 return si, 0, CorruptInputError(len(src))
358 }
359 if rune(src[si]) != enc.padChar {
360
361 return si, 0, CorruptInputError(si - 1)
362 }
363
364 si++
365 }
366
367
368 for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
369 si++
370 }
371 if si < len(src) {
372
373 err = CorruptInputError(si)
374 }
375 dlen = j
376 break
377 }
378
379
380 val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
381 dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
382 switch dlen {
383 case 4:
384 dst[2] = dbuf[2]
385 dbuf[2] = 0
386 fallthrough
387 case 3:
388 dst[1] = dbuf[1]
389 if enc.strict && dbuf[2] != 0 {
390 return si, 0, CorruptInputError(si - 1)
391 }
392 dbuf[1] = 0
393 fallthrough
394 case 2:
395 dst[0] = dbuf[0]
396 if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
397 return si, 0, CorruptInputError(si - 2)
398 }
399 }
400
401 return si, dlen - 1, err
402 }
403
404
405
406
407
408 func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
409
410 n := len(src)
411 for n > 0 && rune(src[n-1]) == enc.padChar {
412 n--
413 }
414 n = decodedLen(n, NoPadding)
415
416 dst = slices.Grow(dst, n)
417 n, err := enc.Decode(dst[len(dst):][:n], src)
418 return dst[:len(dst)+n], err
419 }
420
421
422
423
424 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
425 dbuf := make([]byte, enc.DecodedLen(len(s)))
426 n, err := enc.Decode(dbuf, []byte(s))
427 return dbuf[:n], err
428 }
429
430 type decoder struct {
431 err error
432 readErr error
433 enc *Encoding
434 r io.Reader
435 buf [1024]byte
436 nbuf int
437 out []byte
438 outbuf [1024 / 4 * 3]byte
439 }
440
441 func (d *decoder) Read(p []byte) (n int, err error) {
442
443 if len(d.out) > 0 {
444 n = copy(p, d.out)
445 d.out = d.out[n:]
446 return n, nil
447 }
448
449 if d.err != nil {
450 return 0, d.err
451 }
452
453
454
455
456 for d.nbuf < 4 && d.readErr == nil {
457 nn := len(p) / 3 * 4
458 if nn < 4 {
459 nn = 4
460 }
461 if nn > len(d.buf) {
462 nn = len(d.buf)
463 }
464 nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
465 d.nbuf += nn
466 }
467
468 if d.nbuf < 4 {
469 if d.enc.padChar == NoPadding && d.nbuf > 0 {
470
471 var nw int
472 nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
473 d.nbuf = 0
474 d.out = d.outbuf[:nw]
475 n = copy(p, d.out)
476 d.out = d.out[n:]
477 if n > 0 || len(p) == 0 && len(d.out) > 0 {
478 return n, nil
479 }
480 if d.err != nil {
481 return 0, d.err
482 }
483 }
484 d.err = d.readErr
485 if d.err == io.EOF && d.nbuf > 0 {
486 d.err = io.ErrUnexpectedEOF
487 }
488 return 0, d.err
489 }
490
491
492 nr := d.nbuf / 4 * 4
493 nw := d.nbuf / 4 * 3
494 if nw > len(p) {
495 nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
496 d.out = d.outbuf[:nw]
497 n = copy(p, d.out)
498 d.out = d.out[n:]
499 } else {
500 n, d.err = d.enc.Decode(p, d.buf[:nr])
501 }
502 d.nbuf -= nr
503 copy(d.buf[:d.nbuf], d.buf[nr:])
504 return n, d.err
505 }
506
507
508
509
510
511
512
513 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
514 if len(src) == 0 {
515 return 0, nil
516 }
517
518
519
520
521 _ = enc.decodeMap
522
523 si := 0
524 for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
525 src2 := src[si : si+8]
526 if dn, ok := assemble64(
527 enc.decodeMap[src2[0]],
528 enc.decodeMap[src2[1]],
529 enc.decodeMap[src2[2]],
530 enc.decodeMap[src2[3]],
531 enc.decodeMap[src2[4]],
532 enc.decodeMap[src2[5]],
533 enc.decodeMap[src2[6]],
534 enc.decodeMap[src2[7]],
535 ); ok {
536 byteorder.BEPutUint64(dst[n:], dn)
537 n += 6
538 si += 8
539 } else {
540 var ninc int
541 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
542 n += ninc
543 if err != nil {
544 return n, err
545 }
546 }
547 }
548
549 for len(src)-si >= 4 && len(dst)-n >= 4 {
550 src2 := src[si : si+4]
551 if dn, ok := assemble32(
552 enc.decodeMap[src2[0]],
553 enc.decodeMap[src2[1]],
554 enc.decodeMap[src2[2]],
555 enc.decodeMap[src2[3]],
556 ); ok {
557 byteorder.BEPutUint32(dst[n:], dn)
558 n += 3
559 si += 4
560 } else {
561 var ninc int
562 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
563 n += ninc
564 if err != nil {
565 return n, err
566 }
567 }
568 }
569
570 for si < len(src) {
571 var ninc int
572 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
573 n += ninc
574 if err != nil {
575 return n, err
576 }
577 }
578 return n, err
579 }
580
581
582
583
// assemble32 packs four 6-bit values into the upper 24 bits of a uint32,
// leaving the low 8 bits zero. It returns ok == false, with dn == 0, when
// the bitwise OR of the inputs is 0xff.
func assemble32(n1, n2, n3, n4 byte) (dn uint32, ok bool) {
	// Valid decode map values are below 64, so the OR of the inputs can
	// only reach 0xff if at least one of them is the 0xff marker. One
	// comparison therefore replaces four.
	if n1|n2|n3|n4 == 0xff {
		return 0, false
	}

	dn = uint32(n1) << 26
	dn |= uint32(n2) << 20
	dn |= uint32(n3) << 14
	dn |= uint32(n4) << 8
	return dn, true
}
596
597
598
599
// assemble64 packs eight 6-bit values into the upper 48 bits of a uint64,
// leaving the low 16 bits zero. It returns ok == false, with dn == 0, when
// the bitwise OR of the inputs is 0xff.
func assemble64(n1, n2, n3, n4, n5, n6, n7, n8 byte) (dn uint64, ok bool) {
	// Valid decode map values are below 64, so the OR of the inputs can
	// only reach 0xff if at least one of them is the 0xff marker. One
	// comparison therefore replaces eight.
	if n1|n2|n3|n4|n5|n6|n7|n8 == 0xff {
		return 0, false
	}

	dn = uint64(n1) << 58
	dn |= uint64(n2) << 52
	dn |= uint64(n3) << 46
	dn |= uint64(n4) << 40
	dn |= uint64(n5) << 34
	dn |= uint64(n6) << 28
	dn |= uint64(n7) << 22
	dn |= uint64(n8) << 16
	return dn, true
}
616
// newlineFilteringReader wraps an io.Reader and removes every '\r' and
// '\n' from the data it delivers.
type newlineFilteringReader struct {
	wrapped io.Reader
}

// Read reads from the wrapped reader into p and compacts the data in place
// so that it contains no '\r' or '\n'. When a read yields nothing but
// newline characters it reads again, so a zero count is returned only when
// the wrapped reader itself returns no data.
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
	for {
		n, err := r.wrapped.Read(p)
		if n <= 0 {
			return n, err
		}

		// Slide every kept byte down over the newlines that came before
		// it.
		kept := 0
		for _, c := range p[:n] {
			if c == '\r' || c == '\n' {
				continue
			}
			p[kept] = c
			kept++
		}
		if kept > 0 {
			return kept, err
		}
		// The whole chunk was newlines; fetch another one.
	}
}
641
642
643 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
644 return &decoder{enc: enc, r: &newlineFilteringReader{r}}
645 }
646
647
648
649 func (enc *Encoding) DecodedLen(n int) int {
650 return decodedLen(n, enc.padChar)
651 }
652
653 func decodedLen(n int, padChar rune) int {
654 if padChar == NoPadding {
655
656 return n/4*3 + n%4*6/8
657 }
658
659 return n / 4 * 3
660 }
661
View as plain text