2323JETPACK_VERSION = None
2424
2525__version__ = '1.2.0a0'
26-
26+ FX2TRT_ONLY = False
2727
2828def get_git_revision_short_hash () -> str :
2929 return subprocess .check_output (['git' , 'rev-parse' , '--short' , 'HEAD' ]).decode ('ascii' ).strip ()
3030
31+ if "--fx2trt-only" in sys .argv :
32+ FX2TRT_ONLY = True
33+ sys .argv .remove ("--fx2trt-only" )
3134
3235if "--release" not in sys .argv :
3336 __version__ = __version__ + "+" + get_git_revision_short_hash ()
@@ -138,11 +141,14 @@ def finalize_options(self):
138141 develop .finalize_options (self )
139142
140143 def run (self ):
141- global CXX11_ABI
142- build_libtorchtrt_pre_cxx11_abi (develop = True , cxx11_abi = CXX11_ABI )
143- gen_version_file ()
144- copy_libtorchtrt ()
145- develop .run (self )
144+ if FX2TRT_ONLY :
145+ develop .run (self )
146+ else :
147+ global CXX11_ABI
148+ build_libtorchtrt_pre_cxx11_abi (develop = True , cxx11_abi = CXX11_ABI )
149+ gen_version_file ()
150+ copy_libtorchtrt ()
151+ develop .run (self )
146152
147153
148154class InstallCommand (install ):
@@ -155,11 +161,14 @@ def finalize_options(self):
155161 install .finalize_options (self )
156162
157163 def run (self ):
158- global CXX11_ABI
159- build_libtorchtrt_pre_cxx11_abi (develop = False , cxx11_abi = CXX11_ABI )
160- gen_version_file ()
161- copy_libtorchtrt ()
162- install .run (self )
164+ if FX2TRT_ONLY :
165+ install .run (self )
166+ else :
167+ global CXX11_ABI
168+ build_libtorchtrt_pre_cxx11_abi (develop = False , cxx11_abi = CXX11_ABI )
169+ gen_version_file ()
170+ copy_libtorchtrt ()
171+ install .run (self )
163172
164173
165174class BdistCommand (bdist_wheel ):
@@ -254,6 +263,23 @@ def run(self):
254263 ] + (["-D_GLIBCXX_USE_CXX11_ABI=1" ] if CXX11_ABI else ["-D_GLIBCXX_USE_CXX11_ABI=0" ]),
255264 undef_macros = ["NDEBUG" ])
256265]
266+ if FX2TRT_ONLY :
267+ ext_modules = None
268+ packages = [
269+ "torch_tensorrt.fx" ,
270+ "torch_tensorrt.fx.converters" ,
271+ "torch_tensorrt.fx.passes" ,
272+ "torch_tensorrt.fx.tools" ,
273+ "torch_tensorrt.fx.tracer.acc_tracer" ,
274+ ]
275+ package_dir = {
276+ "torch_tensorrt.fx" : "torch_tensorrt/fx" ,
277+ "torch_tensorrt.fx.converters" : "torch_tensorrt/fx/converters" ,
278+ "torch_tensorrt.fx.passes" : "torch_tensorrt/fx/passes" ,
279+ "torch_tensorrt.fx.tools" : "torch_tensorrt/fx/tools" ,
280+ "torch_tensorrt.fx.tracer.acc_tracer" : "torch_tensorrt/fx/tracer/acc_tracer" ,
281+ }
282+
257283
258284with open ("README.md" , "r" , encoding = "utf-8" ) as fh :
259285 long_description = fh .read ()
@@ -282,7 +308,8 @@ def run(self):
282308 },
283309 zip_safe = False ,
284310 license = "BSD" ,
285- packages = find_packages (),
311+ packages = packages if FX2TRT_ONLY else find_packages (),
312+ package_dir = package_dir if FX2TRT_ONLY else {},
286313 classifiers = [
287314 "Development Status :: 5 - Stable" , "Environment :: GPU :: NVIDIA CUDA" ,
288315 "License :: OSI Approved :: BSD License" , "Intended Audience :: Developers" ,
@@ -311,4 +338,4 @@ def run(self):
311338 exclude_package_data = {
312339 '' : ['*.cpp' ],
313340 'torch_tensorrt' : ['csrc/*.cpp' ],
314- })
341+ }),
0 commit comments