11""" Routines to support optional packages """
2+ from distutils .version import LooseVersion
3+
4+ from .externals .six import string_types , callable
25
36try :
47 import nose
1013from .tripwire import TripWire
1114
1215
13- def optional_package (name , trip_msg = None ):
16+ def _check_pkg_version (pkg , min_version ):
17+ # Default version checking function
18+ if isinstance (min_version , string_types ):
19+ min_version = LooseVersion (min_version )
20+ try :
21+ return min_version <= pkg .__version__
22+ except AttributeError :
23+ return False
24+
25+
26+ def optional_package (name , trip_msg = None , min_version = None ):
1427 """ Return package-like thing and module setup for package `name`
1528
1629 Parameters
@@ -19,8 +32,14 @@ def optional_package(name, trip_msg=None):
1932 package name
2033 trip_msg : None or str
2134 message to give when someone tries to use the return package, but we
22- could not import it, and have returned a TripWire object instead.
23- Default message if None.
35+ could not import it at an acceptable version, and have returned a
36+ TripWire object instead. Default message if None.
37+ min_version : None or str or LooseVersion or callable
38+ If None, do not specify a minimum version. If str, convert to a
39+ `distutils.version.LooseVersion`. If str or LooseVersion` compare to
40+ version of package `name` with ``min_version <= pkg.__version__``. If
41+ callable, accepts imported ``pkg`` as argument, and returns value of
42+ callable is True for acceptable package versions, False otherwise.
2443
2544 Returns
2645 -------
@@ -66,6 +85,12 @@ def optional_package(name, trip_msg=None):
6685 >>> hasattr(subpkg, 'dirname')
6786 True
6887 """
88+ if callable (min_version ):
89+ check_version = min_version
90+ elif min_version is None :
91+ check_version = lambda pkg : True
92+ else :
93+ check_version = lambda pkg : _check_pkg_version (pkg , min_version )
6994 # fromlist=[''] results in submodule being returned, rather than the top
7095 # level module. See help(__import__)
7196 fromlist = ['' ] if '.' in name else []
@@ -75,7 +100,15 @@ def optional_package(name, trip_msg=None):
75100 pass
76101 else : # import worked
77102 # top level module
78- return pkg , True , lambda : None
103+ if check_version (pkg ):
104+ return pkg , True , lambda : None
105+ # Failed version check
106+ if trip_msg is None :
107+ if callable (min_version ):
108+ trip_msg = 'Package %s fails version check' % min_version
109+ else :
110+ trip_msg = ('These functions need %s version >= %s' %
111+ (name , min_version ))
79112 if trip_msg is None :
80113 trip_msg = ('We need package %s for these functions, but '
81114 '``import %s`` raised an ImportError'
0 commit comments