1919from pypmml .base import JavaModelWrapper , PMMLContext
2020from pypmml .jvm import PMMLError
2121from pypmml .elements import Header
22- from pypmml .metadata import Field , OutputField , DataDictionary
22+ from pypmml .metadata import Field , OutputField , DataDictionary , DataVal
23+ from pypmml .utils import is_nd_array , is_pandas_series , is_pandas_dataframe
2324
2425
2526class Model (JavaModelWrapper ):
@@ -115,33 +116,13 @@ def outputFields(self):
115116 @property
116117 def classes (self ):
117118 """The class labels in a classification model."""
118- return self .call ('classes' )
119+ values = self .call ('classes' )
120+ return [DataVal (x ).toVal for x in values ]
119121
120122 def setSupplementOutput (self , value ):
121123 self .call ('setSupplementOutput' , value )
122124 return self
123125
124- def _is_nd_array (self , data ):
125- try :
126- import numpy as np
127- return isinstance (data , np .ndarray )
128- except ImportError :
129- return False
130-
131- def _is_pandas_dataframe (self , data ):
132- try :
133- import pandas as pd
134- return isinstance (data , pd .DataFrame )
135- except ImportError :
136- return False
137-
138- def _is_pandas_series (self , data ):
139- try :
140- import pandas as pd
141- return isinstance (data , pd .Series )
142- except ImportError :
143- return False
144-
145126 def predict (self , data ):
146127 """
147128 Predict values for a given data.
@@ -163,8 +144,7 @@ def predict(self, data):
163144 return self .call ('predict' , data )
164145 else :
165146 return []
166- elif self ._is_nd_array (data ):
167- import numpy as np
147+ elif is_nd_array (data ):
168148 if data .ndim == 1 :
169149 return self .call ('predict' , data .tolist ())
170150 elif data .ndim == 2 :
@@ -179,13 +159,13 @@ def predict(self, data):
179159 return [self .call ('predict' , record .tolist ()) for record in data ]
180160 else :
181161 raise PMMLError ('Max 2 dimensions are supported' )
182- elif self . _is_pandas_dataframe (data ):
162+ elif is_pandas_dataframe (data ):
183163 import pandas as pd
184164 from io import StringIO
185165 json_data = data .to_json (orient = 'split' , index = False )
186166 result = self .call ('predict' , json_data )
187167 return pd .read_json (StringIO (result ), orient = 'split' )
188- elif self . _is_pandas_series (data ):
168+ elif is_pandas_series (data ):
189169 import pandas as pd
190170 record = data .to_dict ()
191171 result = self .call ('predict' , record )
0 commit comments