|
26 | 26 |
|
27 | 27 | from pyspark.conf import SparkConf # pylint: disable=import-error,wrong-import-position |
28 | 28 |
|
| 29 | +NEW_DBUTILS = False |
| 30 | + |
29 | 31 | with warnings.catch_warnings(): |
30 | 32 | warnings.simplefilter("ignore") |
31 | 33 | # Suppress py4j loading message on stderr by redirecting sys.stderr |
32 | 34 | stderr_orig = sys.stderr |
33 | 35 | sys.stderr = io.StringIO() |
34 | | - from PythonShell import get_existing_gateway, RemoteContext # pylint: disable=import-error |
| 36 | + try: |
| 37 | + # Up to DBR 7.x |
| 38 | + from PythonShell import get_existing_gateway, RemoteContext # pylint: disable=import-error,wrong-import-position |
| 39 | + except: |
| 40 | + # Above DBR 8.0 |
| 41 | + sys.path.insert(0, "/databricks/python_shell") |
| 42 | + sys.path.insert(0, "/databricks/python_shell/scripts/") |
| 43 | + |
| 44 | + from dbruntime.spark_connection import RemoteContext, get_existing_gateway # pylint: disable=import-error,wrong-import-position |
| 45 | + |
| 46 | + try: |
| 47 | + # up to DBR 8.2 |
| 48 | + from dbutils import DBUtils # pylint: disable=import-error,wrong-import-position |
| 49 | + NEW_DBUTILS = False |
| 50 | + except: |
| 51 | + # above DBR 8.3 |
| 52 | + from dbruntime.dbutils import DBUtils # pylint: disable=import-error,wrong-import-position |
| 53 | + NEW_DBUTILS = True |
35 | 54 |
|
36 | 55 | out = sys.stderr.getvalue() |
37 | 56 | # Restore sys.stderr |
|
40 | 59 | if not "py4j imported" in out: |
41 | 60 | print(out, file=sys.stderr) |
42 | 61 |
|
43 | | -from dbutils import DBUtils # pylint: disable=import-error,wrong-import-position |
44 | | - |
45 | 62 |
|
46 | 63 | class JobInfo: |
47 | 64 | """Job info class for Spark jobs |
@@ -109,8 +126,13 @@ def new_group_id(self): |
109 | 126 |
|
110 | 127 |
|
111 | 128 | class DbjlUtils: |
112 | | - def __init__(self, shell, entry_point): |
113 | | - self._dbutils = DBUtils(shell, entry_point) |
| 129 | + def __init__(self, shell, entry_point, sc, sqlContext, displayHTML): |
| 130 | + # ugly, but not possible to differentiate <= 8.2 from >= 8.3 |
| 131 | + try: |
| 132 | + self._dbutils = DBUtils(shell, entry_point) |
| 133 | + except: |
| 134 | + self._dbutils = DBUtils(shell, entry_point, sc, sqlContext, displayHTML) |
| 135 | + |
114 | 136 | self.fs = self._dbutils.fs |
115 | 137 | self.secrets = self._dbutils.secrets |
116 | 138 | self.notebook = Notebook() |
@@ -464,7 +486,7 @@ def get_config(self): # pylint: disable=unused-argument |
464 | 486 |
|
465 | 487 | # Initialize dbutils |
466 | 488 | # |
467 | | - dbutils = DbjlUtils(shell, entry_point) |
| 489 | + dbutils = DbjlUtils(shell, entry_point, sc, sqlContext, shell.displayHTML) |
468 | 490 |
|
469 | 491 | # Setting up Spark progress bar |
470 | 492 | # |
|
0 commit comments