Skip to content

Commit e809d49

Browse files
committed
Refactor the code
1 parent 8944470 commit e809d49

File tree

7 files changed

+79
-38
lines changed

7 files changed

+79
-38
lines changed

.github/workflows/python-publish.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ jobs:
3030
run: |
3131
python -m pip install --upgrade pip
3232
pip install build
33+
- name: Test with pytest
34+
run: |
35+
pip install pytest pytest-cov
36+
pytest
3337
- name: Build package
3438
run: python -m build --sdist
3539
- name: Publish package

pypmml/jvm.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ def __init__(self):
2626
self.java_path = None
2727

2828
# JVM options
29-
# Fix IllegalAccessError: cannot access class jdk.internal.math.FloatingDecimal
30-
self.java_opts = ["--add-exports=java.base/jdk.internal.math=ALL-UNNAMED"]
29+
self.java_opts = []
3130
java_opts = os.environ.get("JAVA_OPTS")
3231
if java_opts:
3332
self.java_opts.extend(java_opts.split())
@@ -144,7 +143,7 @@ def name(self):
144143

145144
class Py4jGateway(JVMGateway):
146145
"""Py4j"""
147-
from py4j.java_collections import JavaArray
146+
from py4j.java_collections import JavaArray, JavaList
148147
from py4j.java_gateway import JavaObject
149148
from py4j.protocol import Py4JJavaError
150149

@@ -179,7 +178,7 @@ def detach(self, java_object):
179178
Py4jGateway._gateway.detach(java_object)
180179

181180
def java2py(self, r):
182-
if isinstance(r, self.JavaArray):
181+
if isinstance(r, (self.JavaArray, self.JavaList)):
183182
return [self.java2py(x) for x in r]
184183
elif isinstance(r, self.JavaObject):
185184
cls_name = r.getClass().getName()

pypmml/metadata.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ def targetField(self):
5858

5959
@property
6060
def value(self):
61-
return self.call('value')
61+
obj = self.call('value')
62+
if obj is not None:
63+
return DataVal(obj).toVal
64+
else:
65+
return None
6266

6367
@property
6468
def ruleFeature(self):
@@ -88,3 +92,12 @@ def fieldNames(self):
8892
def get(self, name):
    """Look up a field by name and wrap the result as a ``Field``.

    :param name: name of the field to look up.
    :return: a ``Field`` wrapping whatever the lookup returned, or ``None``
        when the lookup returned ``None``.
    """
    # Delegate through ``call`` like the other accessors of these wrappers.
    # The previous ``self.get(name)`` re-entered this very method with the
    # same argument and could only terminate with a RecursionError.
    # NOTE(review): assumes the wrapped object exposes a 'get' method - confirm.
    fld = self.call('get', name)
    return Field(fld) if fld is not None else None
95+
96+
97+
class DataVal(JavaModelWrapper):
98+
def __init__(self, java_model):
99+
super(DataVal, self).__init__(java_model)
100+
101+
@property
102+
def toVal(self):
103+
return self.call('toVal')

pypmml/model.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from pypmml.base import JavaModelWrapper, PMMLContext
2020
from pypmml.jvm import PMMLError
2121
from pypmml.elements import Header
22-
from pypmml.metadata import Field, OutputField, DataDictionary
22+
from pypmml.metadata import Field, OutputField, DataDictionary, DataVal
23+
from pypmml.utils import is_nd_array, is_pandas_series, is_pandas_dataframe
2324

2425

2526
class Model(JavaModelWrapper):
@@ -115,33 +116,13 @@ def outputFields(self):
115116
@property
116117
def classes(self):
117118
"""The class labels in a classification model."""
118-
return self.call('classes')
119+
values = self.call('classes')
120+
return [DataVal(x).toVal for x in values]
119121

120122
def setSupplementOutput(self, value):
    """Forward ``value`` to the wrapped model's ``setSupplementOutput``.

    :param value: passed through unchanged to ``call``.
    :return: ``self``, so that calls can be chained.
    """
    self.call('setSupplementOutput', value)
    return self
123125

124-
def _is_nd_array(self, data):
125-
try:
126-
import numpy as np
127-
return isinstance(data, np.ndarray)
128-
except ImportError:
129-
return False
130-
131-
def _is_pandas_dataframe(self, data):
132-
try:
133-
import pandas as pd
134-
return isinstance(data, pd.DataFrame)
135-
except ImportError:
136-
return False
137-
138-
def _is_pandas_series(self, data):
139-
try:
140-
import pandas as pd
141-
return isinstance(data, pd.Series)
142-
except ImportError:
143-
return False
144-
145126
def predict(self, data):
146127
"""
147128
Predict values for a given data.
@@ -163,8 +144,7 @@ def predict(self, data):
163144
return self.call('predict', data)
164145
else:
165146
return []
166-
elif self._is_nd_array(data):
167-
import numpy as np
147+
elif is_nd_array(data):
168148
if data.ndim == 1:
169149
return self.call('predict', data.tolist())
170150
elif data.ndim == 2:
@@ -179,13 +159,13 @@ def predict(self, data):
179159
return [self.call('predict', record.tolist()) for record in data]
180160
else:
181161
raise PMMLError('Max 2 dimensions are supported')
182-
elif self._is_pandas_dataframe(data):
162+
elif is_pandas_dataframe(data):
183163
import pandas as pd
184164
from io import StringIO
185165
json_data = data.to_json(orient='split', index=False)
186166
result = self.call('predict', json_data)
187167
return pd.read_json(StringIO(result), orient='split')
188-
elif self._is_pandas_series(data):
168+
elif is_pandas_series(data):
189169
import pandas as pd
190170
record = data.to_dict()
191171
result = self.call('predict', record)

pypmml/test/test_model.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616

1717
import unittest
1818
from unittest import TestCase
19+
from os import path
1920

2021
from pypmml import Model, PMMLContext
2122

2223

2324
class ModelTestCase(TestCase):
25+
test_data_dir = path.join(path.dirname(__file__), 'resources', 'data')
26+
test_models_dir = path.join(path.dirname(__file__), 'resources', 'models')
2427

2528
def test_from_file(self):
2629
# The model is from here: http://dmg.org/pmml/pmml_examples/KNIME_PMML_4.1_Examples/single_iris_dectree.xml
27-
model = Model.fromFile('./resources/models/single_iris_dectree.xml')
30+
model_path = path.join(self.test_models_dir, 'single_iris_dectree.xml')
31+
model = Model.fromFile(model_path)
2832
self.assertEqual(model.version, '4.1')
2933

3034
app = model.header.application
@@ -115,7 +119,8 @@ def test_from_file(self):
115119
def test_pandas(self):
116120
try:
117121
import pandas as pd
118-
model = Model.load('./resources/models/single_iris_dectree.xml')
122+
model_path = path.join(self.test_models_dir, 'single_iris_dectree.xml')
123+
model = Model.load(model_path)
119124

120125
# Data in Series
121126
result = model.predict(pd.Series({'sepal_length': 5.1, 'sepal_width': 3.5, 'petal_length': 1.4, 'petal_width': 0.2}))
@@ -124,7 +129,8 @@ def test_pandas(self):
124129
self.assertEqual(result.get('node_id'), '1')
125130

126131
# Data in DataFrame
127-
data = pd.read_csv('./resources/data/Iris.csv')
132+
data_path = path.join(self.test_data_dir, 'Iris.csv')
133+
data = pd.read_csv(data_path)
128134
result = model.predict(data)
129135
self.assertEqual(result.iloc[0].get('predicted_class'), 'Iris-setosa')
130136
self.assertEqual(result.iloc[0].get('probability'), 1.0)
@@ -135,7 +141,8 @@ def test_pandas(self):
135141
def test_numpy(self):
136142
try:
137143
import numpy as np
138-
model = Model.load('./resources/models/single_iris_dectree.xml')
144+
model_path = path.join(self.test_models_dir, 'single_iris_dectree.xml')
145+
model = Model.load(model_path)
139146

140147
# Data in 1-D
141148
result = model.predict(np.array([5.1, 3.5, 1.4, 0.2]))
@@ -166,7 +173,7 @@ def test_numpy(self):
166173
pass
167174

168175
def test_load(self):
169-
file_path = './resources/models/single_iris_dectree.xml'
176+
file_path = path.join(self.test_models_dir, 'single_iris_dectree.xml')
170177
self.assertTrue(Model.load(file_path) is not None)
171178

172179
with open(file_path, 'rb') as f:

pypmml/utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#
2+
# Copyright (c) 2024 AutoDeployAI
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
def is_nd_array(data):
18+
try:
19+
import numpy as np
20+
return isinstance(data, np.ndarray)
21+
except ImportError:
22+
return False
23+
24+
25+
def is_pandas_dataframe(data):
26+
try:
27+
import pandas as pd
28+
return isinstance(data, pd.DataFrame)
29+
except ImportError:
30+
return False
31+
32+
33+
def is_pandas_series(data):
34+
try:
35+
import pandas as pd
36+
return isinstance(data, pd.Series)
37+
except ImportError:
38+
return False

pypmml/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818

1919
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
2020
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
21-
__version__ = '1.5.2'
21+
__version__ = '1.5.3'

0 commit comments

Comments
 (0)