|
1 | | -from typing import Iterator, List |
| 1 | +from base64 import b64encode, b64decode |
| 2 | +import binascii |
2 | 3 |
|
3 | 4 | __all__ = ["base64", "unbase64"] |
4 | 5 |
|
|
7 | 8 |
|
8 | 9 | def base64(s: str) -> Base64String: |
9 | 10 | """Encode the string s using Base64.""" |
10 | | - if isinstance(s, (bytearray, bytes)): |
11 | | - s = s.decode("unicode") # handle encoded string gracefully |
| 11 | + b: bytes = s.encode("utf-8") if isinstance(s, str) else s |
| 12 | + return b64encode(b).decode("ascii") |
12 | 13 |
|
13 | | - unicode_list = list(str_to_unicode_seq(s)) |
14 | | - length = len(unicode_list) |
15 | | - rest = length % 3 |
16 | | - result: List[str] = [] |
17 | | - extend = result.extend |
18 | 14 |
|
19 | | - for i in range(0, length - rest, 3): |
20 | | - a, b, c = unicode_list[i : i + 3] |
21 | | - result.extend( |
22 | | - ( |
23 | | - first_6_bits(a), |
24 | | - last_2_bits_and_first_4_bits(a, b), |
25 | | - last_4_bits_and_first_2_bits(b, c), |
26 | | - last_6_bits(c), |
27 | | - ) |
28 | | - ) |
29 | | - |
30 | | - if rest == 1: |
31 | | - a = unicode_list[-1] |
32 | | - extend((first_6_bits(a), last_2_bits_and_first_4_bits(a, 0), "==")) |
33 | | - elif rest == 2: |
34 | | - a, b = unicode_list[-2:] |
35 | | - extend( |
36 | | - ( |
37 | | - first_6_bits(a), |
38 | | - last_2_bits_and_first_4_bits(a, b), |
39 | | - last_4_bits_and_first_2_bits(b, 0), |
40 | | - "=", |
41 | | - ) |
42 | | - ) |
43 | | - |
44 | | - return "".join(result) |
45 | | - |
46 | | - |
47 | | -def first_6_bits(a: int) -> str: |
48 | | - return to_base_64_char(a >> 2 & 0x3F) |
49 | | - |
50 | | - |
51 | | -def last_2_bits_and_first_4_bits(a: int, b: int) -> str: |
52 | | - return to_base_64_char((a << 4 | b >> 4) & 0x3F) |
53 | | - |
54 | | - |
55 | | -def last_4_bits_and_first_2_bits(b: int, c: int) -> str: |
56 | | - return to_base_64_char((b << 2 | c >> 6) & 0x3F) |
57 | | - |
58 | | - |
59 | | -def last_6_bits(c: int) -> str: |
60 | | - return to_base_64_char(c & 0x3F) |
61 | | - |
62 | | - |
63 | | -def unbase64(s: str) -> str: |
| 15 | +def unbase64(s: Base64String) -> str: |
64 | 16 | """Decode the string s using Base64.""" |
65 | | - if isinstance(s, (bytearray, bytes)): |
66 | | - s = s.decode("ascii") # handle encoded string gracefully |
67 | | - |
68 | | - unicode_list: List[int] = [] |
69 | | - extend = unicode_list.extend |
70 | | - length = len(s) |
71 | | - |
72 | | - for i in range(0, length, 4): |
73 | | - try: |
74 | | - a, b, c, d = [from_base_64_char(char) for char in s[i : i + 4]] |
75 | | - except (KeyError, ValueError): |
76 | | - return "" # for compatibility |
77 | | - bitmap_24 = a << 18 | b << 12 | c << 6 | d |
78 | | - extend((bitmap_24 >> 16 & 0xFF, bitmap_24 >> 8 & 0xFF, bitmap_24 & 0xFF)) |
79 | | - |
80 | | - i = length - 1 |
81 | | - while i > 0 and s[i] == "=": |
82 | | - i -= 1 |
83 | | - unicode_list.pop() |
84 | | - |
85 | | - return "".join(unicode_list_to_str(unicode_list)) |
86 | | - |
87 | | - |
88 | | -b64_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" |
89 | | - |
90 | | -b64_character_map = {c: i for i, c in enumerate(b64_characters)} |
91 | | - |
92 | | - |
93 | | -def to_base_64_char(bit_map_6: int) -> str: |
94 | | - return b64_characters[bit_map_6] |
95 | | - |
96 | | - |
97 | | -def from_base_64_char(base_64_char: str) -> int: |
98 | | - return 0 if base_64_char == "=" else b64_character_map[base_64_char] |
99 | | - |
100 | | - |
101 | | -def str_to_unicode_seq(s: str) -> Iterator[int]: |
102 | | - for utf_char in s: |
103 | | - code = ord(utf_char) |
104 | | - if code < 0x80: |
105 | | - yield code |
106 | | - elif code < 0x800: |
107 | | - yield 0xC0 | code >> 6 |
108 | | - yield 0x80 | code & 0x3F |
109 | | - elif code < 0x10000: |
110 | | - yield 0xE0 | code >> 12 |
111 | | - yield 0x80 | code >> 6 & 0x3F |
112 | | - yield 0x80 | code & 0x3F |
113 | | - else: |
114 | | - yield 0xF0 | code >> 18 |
115 | | - yield 0x80 | code >> 12 & 0x3F |
116 | | - yield 0x80 | code >> 6 & 0x3F |
117 | | - yield 0x80 | code & 0x3F |
118 | | - |
119 | | - |
120 | | -def unicode_list_to_str(s: List[int]) -> Iterator[str]: |
121 | | - s.reverse() |
122 | | - next_code = s.pop |
123 | | - while s: |
124 | | - a = next_code() |
125 | | - if a & 0x80 == 0: |
126 | | - yield chr(a) |
127 | | - continue |
128 | | - b = next_code() |
129 | | - if a & 0xE0 == 0xC0: |
130 | | - yield chr((a & 0x1F) << 6 | b & 0x3F) |
131 | | - continue |
132 | | - c = next_code() |
133 | | - if a & 0xF0 == 0xE0: |
134 | | - yield chr((a & 0x0F) << 12 | (b & 0x3F) << 6 | c & 0x3F) |
135 | | - continue |
136 | | - d = next_code() |
137 | | - yield chr((a & 0x07) << 18 | (b & 0x3F) << 12 | (c & 0x3F) << 6 | d & 0x3F) |
| 17 | + try: |
| 18 | + b: bytes = s.encode("ascii") if isinstance(s, str) else s |
| 19 | + except UnicodeEncodeError: |
| 20 | + return "" |
| 21 | + try: |
| 22 | + return b64decode(b).decode("utf-8") |
| 23 | + except (binascii.Error, UnicodeDecodeError): |
| 24 | + return "" |
0 commit comments