1313
1414
1515class Expander :
16+ local_include = re .compile (
17+ r'#include\s*"([a-z_]*(|.hpp))"\s*' )
1618 atcoder_include = re .compile (
17- '#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*' )
19+ r '#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*' )
1820
19- include_guard = re .compile ('#.*ATCODER_[A-Z_]*_HPP' )
21+ include_guard = re .compile (r '#.*ATCODER_[A-Z_]*_HPP' )
2022
2123 def is_ignored_line (self , line ) -> bool :
2224 if self .include_guard .match (line ):
@@ -30,27 +32,24 @@ def is_ignored_line(self, line) -> bool:
3032 def __init__ (self , lib_paths : List [Path ]):
3133 self .lib_paths = lib_paths
3234
33- included = set () # type: Set[str ]
35+ included = set () # type: Set[Path ]
3436
35- def find_acl (self , acl_name : str ) -> Optional [ Path ] :
37+ def find_acl (self , acl_name : str ) -> Path :
3638 for lib_path in self .lib_paths :
3739 path = lib_path / acl_name
3840 if path .exists ():
3941 return path
40- return None
42+ logger .error ('cannot find: {}' .format (acl_name ))
43+ raise FileNotFoundError ()
4144
42- def expand_acl (self , acl_name : str ) -> List [str ]:
43- if acl_name in self .included :
44- logger .info ('already included: {}' .format (acl_name ))
45+ def expand_acl (self , acl_file_path : Path ) -> List [str ]:
46+ if acl_file_path in self .included :
47+ logger .info ('already included: {}' .format (acl_file_path . name ))
4548 return []
46- self .included .add (acl_name )
47- logger .info ('include: {}' .format (acl_name ))
48- acl_path = self .find_acl (acl_name )
49- if not acl_path :
50- logger .warning ('cannot find: {}' .format (acl_name ))
51- raise FileNotFoundError ()
49+ self .included .add (acl_file_path )
50+ logger .info ('include: {}' .format (acl_file_path .name ))
5251
53- acl_source = open (str (acl_path )).read ()
52+ acl_source = open (str (acl_file_path )).read ()
5453
5554 result = [] # type: List[str]
5655 for line in acl_source .splitlines ():
@@ -59,7 +58,14 @@ def expand_acl(self, acl_name: str) -> List[str]:
5958
6059 m = self .atcoder_include .match (line )
6160 if m :
62- result .extend (self .expand_acl (m .group (1 )))
61+ name = m .group (1 )
62+ result .extend (self .expand_acl (self .find_acl (name )))
63+ continue
64+
65+ m = self .local_include .match (line )
66+ if m :
67+ name = m .group (1 )
68+ result .extend (self .expand_acl (acl_file_path .parent / name ))
6369 continue
6470
6571 result .append (line )
@@ -71,10 +77,11 @@ def expand(self, source: str) -> str:
7177 result = [] # type: List[str]
7278 for line in source .splitlines ():
7379 m = self .atcoder_include .match (line )
74-
7580 if m :
76- result .extend (self .expand_acl (m .group (1 )))
81+ acl_path = self .find_acl (m .group (1 ))
82+ result .extend (self .expand_acl (acl_path ))
7783 continue
84+
7885 result .append (line )
7986 return '\n ' .join (result )
8087
0 commit comments