Skip to content

Commit 422156a

Browse files
author
Gourav Falaswal
committed
Update Python wrapperrs for remote cloud client
1 parent 9661dca commit 422156a

File tree

3 files changed

+43
-38
lines changed

3 files changed

+43
-38
lines changed

tensorflow/lite/remote/interpreter_wrapper2/interpreter_wrapper2.cc

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <memory>
2222
#include <sstream>
2323
#include <string>
24+
#include <utility>
2425

2526
#include "absl/memory/memory.h"
2627
#include "absl/strings/str_format.h"
@@ -58,15 +59,19 @@ using python_utils::PyDecrefDeleter;
5859

5960
PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
6061
void* pydata = malloc(size * sizeof(float));
61-
memcpy(pydata, data, size * sizeof(float));
62+
if (data != nullptr) {
63+
memcpy(pydata, data, size * sizeof(float));
64+
}
6265
PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
6366
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
6467
return obj;
6568
}
6669

6770
PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
6871
void* pydata = malloc(size * sizeof(int));
69-
memcpy(pydata, data, size * sizeof(int));
72+
if (data != nullptr) {
73+
memcpy(pydata, data, size * sizeof(int));
74+
}
7075
PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
7176
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
7277
return obj;
@@ -270,9 +275,9 @@ PyObject* InterpreterWrapper2::ResizeInputTensor(int i, PyObject* value,
270275
/*
271276
* This returns an int, and never returns an error
272277
*/
273-
int InterpreterWrapper2::NumTensors() const {
278+
int InterpreterWrapper2::NumTensors(int subgraph_index) const {
274279
if(local_exec)
275-
return lwrap->NumTensors();
280+
return lwrap->NumTensors(subgraph_index);
276281
try {
277282
roundtrip(tflite_num_tensors, proxy.id);
278283
return resp.count();
@@ -284,9 +289,9 @@ int InterpreterWrapper2::NumTensors() const {
284289
/*
285290
* This returns a string, and never returns an error
286291
*/
287-
std::string InterpreterWrapper2::TensorName(int i) const {
292+
std::string InterpreterWrapper2::TensorName(int i, int subgraph_index) const {
288293
if(local_exec)
289-
return lwrap->TensorName(i);
294+
return lwrap->TensorName(i, subgraph_index);
290295
try {
291296
roundtrip(tflite_tensor_name, proxy.id, i);
292297
return resp.name();
@@ -298,9 +303,9 @@ std::string InterpreterWrapper2::TensorName(int i) const {
298303
/*
299304
* Get TfLiteType from remote and convert to Numpy Type class
300305
*/
301-
PyObject* InterpreterWrapper2::TensorType(int i) const {
306+
PyObject* InterpreterWrapper2::TensorType(int i, int subgraph_index) const {
302307
if(local_exec)
303-
return lwrap->TensorType(i);
308+
return lwrap->TensorType(i, subgraph_index);
304309
try {
305310
roundtrip(tflite_tensor_type, proxy.id, i);
306311
if(resp.status())
@@ -319,9 +324,9 @@ PyObject* InterpreterWrapper2::TensorType(int i) const {
319324
return nullptr;
320325
}
321326

322-
PyObject* InterpreterWrapper2::TensorSize(int i) const {
327+
PyObject* InterpreterWrapper2::TensorSize(int i, int subgraph_index) const {
323328
if(local_exec)
324-
return lwrap->TensorSize(i);
329+
return lwrap->TensorSize(i, subgraph_index);
325330
try {
326331
roundtrip(tflite_tensor_size, proxy.id, i);
327332
if(resp.status())
@@ -337,9 +342,9 @@ PyObject* InterpreterWrapper2::TensorSize(int i) const {
337342
return nullptr;
338343
}
339344

340-
PyObject* InterpreterWrapper2::TensorSizeSignature(int i) const {
345+
PyObject* InterpreterWrapper2::TensorSizeSignature(int i, int subgraph_index) const {
341346
if(local_exec)
342-
return lwrap->TensorSizeSignature(i);
347+
return lwrap->TensorSizeSignature(i, subgraph_index);
343348
try {
344349
roundtrip(tflite_tensor_size_signature, proxy.id, i);
345350
if(resp.status())
@@ -359,9 +364,9 @@ PyObject* InterpreterWrapper2::TensorSizeSignature(int i) const {
359364
* Use <PyDictFromSparsityParam> from interpreter_wrapper.cc
360365
* and therefore convert out response to a TfLiteSparsity struct
361366
*/
362-
PyObject* InterpreterWrapper2::TensorSparsityParameters(int i) const {
367+
PyObject* InterpreterWrapper2::TensorSparsityParameters(int i, int subgraph_index) const {
363368
if(local_exec)
364-
return lwrap->TensorSparsityParameters(i);
369+
return lwrap->TensorSparsityParameters(i, subgraph_index);
365370
try {
366371
roundtrip(tflite_tensor_sparsity_parameters, proxy.id, i);
367372
if(resp.status())
@@ -412,9 +417,9 @@ PyObject* InterpreterWrapper2::TensorSparsityParameters(int i) const {
412417
return nullptr;
413418
}
414419

415-
PyObject* InterpreterWrapper2::TensorQuantization(int i) const {
420+
PyObject* InterpreterWrapper2::TensorQuantization(int i, int subgraph_index) const {
416421
if(local_exec)
417-
return lwrap->TensorQuantization(i);
422+
return lwrap->TensorQuantization(i, subgraph_index);
418423
try {
419424
roundtrip(tflite_tensor_quantization, proxy.id, i);
420425
if(resp.status())
@@ -430,9 +435,9 @@ PyObject* InterpreterWrapper2::TensorQuantization(int i) const {
430435
return nullptr;
431436
}
432437

433-
PyObject* InterpreterWrapper2::TensorQuantizationParameters(int i) const {
438+
PyObject* InterpreterWrapper2::TensorQuantizationParameters(int i, int subgraph_index) const {
434439
if(local_exec)
435-
return lwrap->TensorQuantizationParameters(i);
440+
return lwrap->TensorQuantizationParameters(i, subgraph_index);
436441
try {
437442
roundtrip(tflite_tensor_quantization_parameters, proxy.id, i);
438443
if(resp.status())

tensorflow/lite/remote/interpreter_wrapper2/interpreter_wrapper2.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,15 @@ class InterpreterWrapper2 {
7878
PyObject* ResizeInputTensor(int i, PyObject* value, bool strict,
7979
int subgraph_index);
8080

81-
int NumTensors() const;
82-
std::string TensorName(int i) const;
83-
PyObject* TensorType(int i) const;
84-
PyObject* TensorSize(int i) const;
85-
PyObject* TensorSizeSignature(int i) const;
86-
PyObject* TensorSparsityParameters(int i) const;
81+
int NumTensors(int subgraph_index) const;
82+
std::string TensorName(int i, int subgraph_index) const;
83+
PyObject* TensorType(int i, int subgraph_index) const;
84+
PyObject* TensorSize(int i, int subgraph_index) const;
85+
PyObject* TensorSizeSignature(int i, int subgraph_index) const;
86+
PyObject* TensorSparsityParameters(int i, int subgraph_index) const;
8787
// Deprecated in favor of TensorQuantizationScales, below.
88-
PyObject* TensorQuantization(int i) const;
89-
PyObject* TensorQuantizationParameters(int i) const;
88+
PyObject* TensorQuantization(int i, int subgraph_index) const;
89+
PyObject* TensorQuantizationParameters(int i, int subgraph_index) const;
9090
PyObject* SetTensor(int i, PyObject* value, int subgraph_index);
9191
PyObject* GetTensor(int i,int subgraph_index) const;
9292
PyObject* ResetVariableTensors();

tensorflow/lite/remote/interpreter_wrapper2/interpreter_wrapper2_pybind11.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,33 +123,33 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
123123
.def("NumTensors", &InterpreterWrapper::NumTensors)
124124
.def("TensorName", &InterpreterWrapper::TensorName)
125125
.def("TensorType",
126-
[](const InterpreterWrapper& self, int i) {
127-
return tensorflow::PyoOrThrow(self.TensorType(i));
126+
[](const InterpreterWrapper& self, int i, int subgraph_index) {
127+
return tensorflow::PyoOrThrow(self.TensorType(i, subgraph_index));
128128
})
129129
.def("TensorSize",
130-
[](const InterpreterWrapper& self, int i) {
131-
return tensorflow::PyoOrThrow(self.TensorSize(i));
130+
[](const InterpreterWrapper& self, int i, int subgraph_index) {
131+
return tensorflow::PyoOrThrow(self.TensorSize(i, subgraph_index));
132132
})
133133
.def("TensorSizeSignature",
134-
[](const InterpreterWrapper& self, int i) {
135-
return tensorflow::PyoOrThrow(self.TensorSizeSignature(i));
134+
[](const InterpreterWrapper& self, int i, int subgraph_index) {
135+
return tensorflow::PyoOrThrow(self.TensorSizeSignature(i, subgraph_index));
136136
})
137137
.def("TensorSparsityParameters",
138-
[](const InterpreterWrapper& self, int i) {
139-
return tensorflow::PyoOrThrow(self.TensorSparsityParameters(i));
138+
[](const InterpreterWrapper& self, int i, int subgraph_index) {
139+
return tensorflow::PyoOrThrow(self.TensorSparsityParameters(i, subgraph_index));
140140
})
141141
.def(
142142
"TensorQuantization",
143-
[](const InterpreterWrapper& self, int i) {
144-
return tensorflow::PyoOrThrow(self.TensorQuantization(i));
143+
[](const InterpreterWrapper& self, int i, int subgraph_index) {
144+
return tensorflow::PyoOrThrow(self.TensorQuantization(i, subgraph_index));
145145
},
146146
R"pbdoc(
147147
Deprecated in favor of TensorQuantizationParameters.
148148
)pbdoc")
149149
.def(
150150
"TensorQuantizationParameters",
151-
[](InterpreterWrapper& self, int i) {
152-
return tensorflow::PyoOrThrow(self.TensorQuantizationParameters(i));
151+
[](InterpreterWrapper& self, int i, int subgraph_index) {
152+
return tensorflow::PyoOrThrow(self.TensorQuantizationParameters(i, subgraph_index));
153153
})
154154
.def(
155155
"SetTensor",

0 commit comments

Comments
 (0)