@@ -9,6 +9,7 @@ import os
99import subprocess
1010import sys
1111import tempfile
12+ from functools import reduce
1213
1314logging .basicConfig (format = '%(message)s' , level = logging .INFO )
1415
@@ -17,7 +18,8 @@ class RoundTripTask(object):
1718 def __init__ (self , input_filename , action , swift_syntax_test ,
1819 skip_bad_syntax ):
1920 assert action == '-round-trip-parse' or action == '-round-trip-lex'
20- assert type (input_filename ) == unicode
21+ if sys .version_info [0 ] < 3 :
22+ assert type (input_filename ) == unicode
2123 assert type (swift_syntax_test ) == str
2224
2325 assert os .path .isfile (input_filename ), \
@@ -51,9 +53,9 @@ class RoundTripTask(object):
5153 self .output_file .close ()
5254 self .stderr_file .close ()
5355
54- with open (self .output_file .name , 'r ' ) as stdout_in :
56+ with open (self .output_file .name , 'rb ' ) as stdout_in :
5557 self .stdout = stdout_in .read ()
56- with open (self .stderr_file .name , 'r ' ) as stderr_in :
58+ with open (self .stderr_file .name , 'rb ' ) as stderr_in :
5759 self .stderr = stderr_in .read ()
5860
5961 os .remove (self .output_file .name )
@@ -75,7 +77,7 @@ class RoundTripTask(object):
7577 raise RuntimeError ()
7678
7779 contents = '' .join (map (lambda l : l .decode ('utf-8' , errors = 'replace' ),
78- open (self .input_filename ).readlines ()))
80+ open (self .input_filename , 'rb' ).readlines ()))
7981 stdout_contents = self .stdout .decode ('utf-8' , errors = 'replace' )
8082
8183 if contents == stdout_contents :
@@ -92,7 +94,7 @@ def swift_files_in_dir(d):
9294 swift_files = []
9395 for root , dirs , files in os .walk (d ):
9496 for basename in files :
95- if not basename .decode ( 'utf-8' ). endswith ('.swift' ):
97+ if not basename .endswith ('.swift' ):
9698 continue
9799 abs_file = os .path .abspath (os .path .join (root , basename ))
98100 swift_files .append (abs_file )
@@ -149,7 +151,8 @@ This driver invokes swift-syntax-test using -round-trip-lex and
149151 all_input_files = [filename for dir_listing in dir_listings
150152 for filename in dir_listing ]
151153 all_input_files += args .individual_input_files
152- all_input_files = [f .decode ('utf-8' ) for f in all_input_files ]
154+ if sys .version_info [0 ] < 3 :
155+ all_input_files = [f .decode ('utf-8' ) for f in all_input_files ]
153156
154157 if len (all_input_files ) == 0 :
155158 logging .error ('No input files!' )
0 commit comments