99from typing import Sequence
1010from typing import Union
1111
12- DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n '
12+ DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-'
1313
1414
1515def has_coding (line ): # type: (bytes) -> bool
1616 if not line .strip ():
1717 return False
1818 return (
19- line .lstrip ()[0 :1 ] == b'#' and (
19+ line .lstrip ()[:1 ] == b'#' and (
2020 b'unicode' in line or
2121 b'encoding' in line or
2222 b'coding:' in line or
@@ -26,7 +26,7 @@ def has_coding(line): # type: (bytes) -> bool
2626
2727
2828class ExpectedContents (collections .namedtuple (
29- 'ExpectedContents' , ('shebang' , 'rest' , 'pragma_status' ),
29+ 'ExpectedContents' , ('shebang' , 'rest' , 'pragma_status' , 'ending' ),
3030)):
3131 """
3232 pragma_status:
@@ -47,6 +47,8 @@ def is_expected_pragma(self, remove): # type: (bool) -> bool
4747
4848def _get_expected_contents (first_line , second_line , rest , expected_pragma ):
4949 # type: (bytes, bytes, bytes, bytes) -> ExpectedContents
50+ ending = b'\r \n ' if first_line .endswith (b'\r \n ' ) else b'\n '
51+
5052 if first_line .startswith (b'#!' ):
5153 shebang = first_line
5254 potential_coding = second_line
@@ -55,7 +57,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma):
5557 potential_coding = first_line
5658 rest = second_line + rest
5759
58- if potential_coding == expected_pragma :
60+ if potential_coding . rstrip ( b' \r \n ' ) == expected_pragma :
5961 pragma_status = True # type: Optional[bool]
6062 elif has_coding (potential_coding ):
6163 pragma_status = None
@@ -64,7 +66,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma):
6466 rest = potential_coding + rest
6567
6668 return ExpectedContents (
67- shebang = shebang , rest = rest , pragma_status = pragma_status ,
69+ shebang = shebang , rest = rest , pragma_status = pragma_status , ending = ending ,
6870 )
6971
7072
@@ -93,7 +95,7 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
9395 f .truncate ()
9496 f .write (expected .shebang )
9597 if not remove :
96- f .write (expected_pragma )
98+ f .write (expected_pragma + expected . ending )
9799 f .write (expected .rest )
98100
99101 return 1
@@ -102,11 +104,7 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
102104def _normalize_pragma (pragma ): # type: (Union[bytes, str]) -> bytes
103105 if not isinstance (pragma , bytes ):
104106 pragma = pragma .encode ('UTF-8' )
105- return pragma .rstrip () + b'\n '
106-
107-
108- def _to_disp (pragma ): # type: (bytes) -> str
109- return pragma .decode ().rstrip ()
107+ return pragma .rstrip ()
110108
111109
112110def main (argv = None ): # type: (Optional[Sequence[str]]) -> int
@@ -117,7 +115,7 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int
117115 parser .add_argument (
118116 '--pragma' , default = DEFAULT_PRAGMA , type = _normalize_pragma ,
119117 help = 'The encoding pragma to use. Default: {}' .format (
120- _to_disp ( DEFAULT_PRAGMA ),
118+ DEFAULT_PRAGMA . decode ( ),
121119 ),
122120 )
123121 parser .add_argument (
@@ -141,7 +139,7 @@ def main(argv=None): # type: (Optional[Sequence[str]]) -> int
141139 retv |= file_ret
142140 if file_ret :
143141 print (fmt .format (
144- pragma = _to_disp ( args .pragma ), filename = filename ,
142+ pragma = args .pragma . decode ( ), filename = filename ,
145143 ))
146144
147145 return retv
0 commit comments