Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 225e96d

Browse files
authored
Implement initial version of rewrite for df.getitem as attr (#601)
* Implement initial rewrite for df.getitem as attr * Minor fixes for df.getitem used as attribute
1 parent 47a3fd4 commit 225e96d

File tree

3 files changed

+107
-0
lines changed

3 files changed

+107
-0
lines changed

sdc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
# sdc.datatypes.hpat_pandas_dataframe_pass.sdc_nopython_pipeline_lite_register
6363

6464
import sdc.rewrites.dataframe_constructor
65+
import sdc.rewrites.dataframe_getitem_attribute
6566
import sdc.datatypes.hpat_pandas_functions
6667
import sdc.datatypes.hpat_pandas_dataframe_functions
6768
else:
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2020, Intel Corporation All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
#
10+
# Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
16+
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
17+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
18+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
22+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
23+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24+
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
from numba.ir import Assign, Const, Expr, Var
28+
from numba.ir_utils import mk_unique_var
29+
from numba.rewrites import register_rewrite, Rewrite
30+
from numba.types import StringLiteral
31+
from numba.typing import signature
32+
33+
from sdc.config import config_pipeline_hpat_default
34+
from sdc.hiframes.pd_dataframe_type import DataFrameType
35+
36+
37+
if not config_pipeline_hpat_default:
38+
@register_rewrite('after-inference')
39+
class RewriteDataFrameGetItemAttr(Rewrite):
40+
"""
41+
Search for calls of df.attr and replace it with calls of df['attr']:
42+
$0.2 = getattr(value=df, attr=A) -> $const0.0 = const(str, A)
43+
$0.2 = static_getitem(value=df, index=A, index_var=$const0.0)
44+
"""
45+
46+
def match(self, func_ir, block, typemap, calltypes):
47+
self.func_ir = func_ir
48+
self.block = block
49+
self.typemap = typemap
50+
self.calltypes = calltypes
51+
self.getattrs = getattrs = set()
52+
for expr in block.find_exprs(op='getattr'):
53+
obj = typemap[expr.value.name]
54+
if not isinstance(obj, DataFrameType):
55+
continue
56+
if expr.attr in obj.columns:
57+
getattrs.add(expr)
58+
59+
return len(getattrs) > 0
60+
61+
def apply(self):
62+
new_block = self.block.copy()
63+
new_block.clear()
64+
for inst in self.block.body:
65+
if isinstance(inst, Assign) and inst.value in self.getattrs:
66+
const_assign = self._assign_const(inst)
67+
new_block.append(const_assign)
68+
69+
inst = self._assign_getitem(inst, index=const_assign.target)
70+
71+
new_block.append(inst)
72+
73+
return new_block
74+
75+
def _assign_const(self, inst, prefix='$const0'):
76+
"""Create constant from attribute of the instruction."""
77+
const_node = Const(inst.value.attr, inst.loc)
78+
const_var = Var(inst.target.scope, mk_unique_var(prefix), inst.loc)
79+
80+
self.func_ir._definitions[const_var.name] = [const_node]
81+
self.typemap[const_var.name] = StringLiteral(inst.value.attr)
82+
83+
return Assign(const_node, const_var, inst.loc)
84+
85+
def _assign_getitem(self, inst, index):
86+
"""Create getitem instruction from the getattr instruction."""
87+
new_expr = Expr.getitem(inst.value.value, index, inst.loc)
88+
new_inst = Assign(value=new_expr, target=inst.target, loc=inst.loc)
89+
90+
self.func_ir._definitions[inst.target] = [new_expr]
91+
self.calltypes[new_expr] = signature(
92+
self.typemap[inst.target.name],
93+
self.typemap[new_expr.value.name],
94+
self.typemap[new_expr.index.name]
95+
)
96+
97+
return new_inst

sdc/tests/test_dataframe.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,15 @@ def test_impl(df):
13671367

13681368
pd.testing.assert_series_equal(sdc_func(df), test_impl(df))
13691369

1370+
def test_df_getitem_attr(self):
1371+
def test_impl(df):
1372+
return df.A
1373+
1374+
sdc_func = self.jit(test_impl)
1375+
df = gen_df(test_global_input_data_float64)
1376+
1377+
pd.testing.assert_series_equal(sdc_func(df), test_impl(df))
1378+
13701379
@skip_numba_jit
13711380
def test_isin_df1(self):
13721381
def test_impl(df, df2):

0 commit comments

Comments
 (0)