From 2ae5112eb6430b81980b45b9d3a623bef351a91f Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 23 Jul 2025 22:52:01 +0200 Subject: [PATCH 1/3] refactoring --- onnx_extended/onnx2/cpu/_onnx2py.cpp | 153 ++++++++++----------------- 1 file changed, 55 insertions(+), 98 deletions(-) diff --git a/onnx_extended/onnx2/cpu/_onnx2py.cpp b/onnx_extended/onnx2/cpu/_onnx2py.cpp index ff32e911..5429f47d 100644 --- a/onnx_extended/onnx2/cpu/_onnx2py.cpp +++ b/onnx_extended/onnx2/cpu/_onnx2py.cpp @@ -8,6 +8,10 @@ namespace py = pybind11; #define PYDEFINE_PROTO(m, cls) \ py::class_(m, #cls, onnx2::cls::DOC).def(py::init<>()) +#define PYDEFINE_PROTO_WITH_SUBTYPES(m, cls, name) \ + py::class_ name(m, #cls, onnx2::cls::DOC); \ + name.def(py::init<>()); + #define PYADD_PROTO_SERIALIZATION(cls) \ def( \ "ParseFromString", \ @@ -158,137 +162,109 @@ PYBIND11_MODULE(_onnx2py, m) { bind_repeated_field(m, "RepeatedFieldDouble"); bind_repeated_field(m, "RepeatedFieldString"); - py::class_(m, "Message", "Message, base class for all onnx2 classes") - .def(py::init<>()); - py::enum_(m, "OperatorStatus", py::arithmetic()) .value("EXPERIMENTAL", onnx2::OperatorStatus::EXPERIMENTAL) .value("STABLE", onnx2::OperatorStatus::STABLE) .export_values(); + py::enum_(m, "DataType", py::arithmetic()) + .value("UNDEFINED", onnx2::TensorProto::DataType::UNDEFINED) + .value("FLOAT", onnx2::TensorProto::DataType::FLOAT) + .value("UINT8", onnx2::TensorProto::DataType::UINT8) + .value("INT8", onnx2::TensorProto::DataType::INT8) + .value("UINT16", onnx2::TensorProto::DataType::UINT16) + .value("INT16", onnx2::TensorProto::DataType::INT16) + .value("INT32", onnx2::TensorProto::DataType::INT32) + .value("INT64", onnx2::TensorProto::DataType::INT64) + .value("STRING", onnx2::TensorProto::DataType::STRING) + .value("BOOL", onnx2::TensorProto::DataType::BOOL) + .value("FLOAT16", onnx2::TensorProto::DataType::FLOAT16) + .value("DOUBLE", onnx2::TensorProto::DataType::DOUBLE) + .value("UINT32", onnx2::TensorProto::DataType::UINT32) + .value("UINT64", onnx2::TensorProto::DataType::UINT64) + .value("COMPLEX64", onnx2::TensorProto::DataType::COMPLEX64) + .value("COMPLEX128", onnx2::TensorProto::DataType::COMPLEX128) + .value("BFLOAT16", onnx2::TensorProto::DataType::BFLOAT16) + .value("FLOAT8E4M3FN", onnx2::TensorProto::DataType::FLOAT8E4M3FN) + .value("FLOAT8E4M3FNUZ", onnx2::TensorProto::DataType::FLOAT8E4M3FNUZ) + .value("FLOAT8E5M2", onnx2::TensorProto::DataType::FLOAT8E5M2) + .value("FLOAT8E5M2FNUZ", onnx2::TensorProto::DataType::FLOAT8E5M2FNUZ) + .value("UINT4", onnx2::TensorProto::DataType::UINT4) + .value("INT4", onnx2::TensorProto::DataType::INT4) + .value("FLOAT4E2M1", onnx2::TensorProto::DataType::FLOAT4E2M1) + .value("FLOAT8E8M0", onnx2::TensorProto::DataType::FLOAT8E8M0) + .export_values(); + + py::class_(m, "Message", "Message, base class for all onnx2 classes") + .def(py::init<>()); + PYDEFINE_PROTO(m, StringStringEntryProto) .PYFIELD(StringStringEntryProto, key) .PYFIELD(StringStringEntryProto, value) .PYADD_PROTO_SERIALIZATION(StringStringEntryProto); - bind_repeated_field(m, "RepeatedFieldStringStringEntryProto"); - py::class_(m, "OperatorSetIdProto", - "OperatorSetIdProto, opset definition") - .def(py::init<>()) + PYDEFINE_PROTO(m, OperatorSetIdProto) .PYFIELD(OperatorSetIdProto, domain) .PYFIELD(OperatorSetIdProto, version) .PYADD_PROTO_SERIALIZATION(OperatorSetIdProto); - bind_repeated_field(m, "RepeatedFieldOperatorSetIdProto"); - py::class_(m, "TensorAnnotation", - "TensorAnnotation, tensor annotation") - .def(py::init<>()) + PYDEFINE_PROTO(m, TensorAnnotation) .PYFIELD(TensorAnnotation, tensor_name) .PYFIELD(TensorAnnotation, quant_parameter_tensor_names) .PYADD_PROTO_SERIALIZATION(TensorAnnotation); - py::class_( - m, "IntIntListEntryProto", "IntIntListEntryProto, tensor annotation") - .def(py::init<>()) + PYDEFINE_PROTO(m, IntIntListEntryProto) .PYFIELD(IntIntListEntryProto, key) .PYFIELD(IntIntListEntryProto, value) .PYADD_PROTO_SERIALIZATION(IntIntListEntryProto); - bind_repeated_field(m, "RepeatedFieldIntIntListEntryProto"); - py::class_(m, "DeviceConfigurationProto", - "DeviceConfigurationProto") - .def(py::init<>()) + PYDEFINE_PROTO(m, DeviceConfigurationProto) .PYFIELD(DeviceConfigurationProto, name) .PYFIELD(DeviceConfigurationProto, num_devices) .PYFIELD(DeviceConfigurationProto, device) .PYADD_PROTO_SERIALIZATION(DeviceConfigurationProto); - py::class_(m, "SimpleShardedDimProto", - "SimpleShardedDimProto") - .def(py::init<>()) + PYDEFINE_PROTO(m, SimpleShardedDimProto) .PYFIELD_OPTIONAL_INT(SimpleShardedDimProto, dim_value) .PYFIELD(SimpleShardedDimProto, dim_param) .PYFIELD(SimpleShardedDimProto, num_shards) .PYADD_PROTO_SERIALIZATION(SimpleShardedDimProto); - bind_repeated_field(m, "RepeatedFieldSimpleShardedDimProto"); - py::class_(m, "ShardedDimProto", "ShardedDimProto") - .def(py::init<>()) + PYDEFINE_PROTO(m, ShardedDimProto) .PYFIELD(ShardedDimProto, axis) .PYFIELD(ShardedDimProto, simple_sharding) .PYADD_PROTO_SERIALIZATION(ShardedDimProto); - bind_repeated_field(m, "RepeatedFieldShardedDimProto"); - py::class_(m, "ShardingSpecProto", "ShardingSpecProto") - .def(py::init<>()) + PYDEFINE_PROTO(m, ShardingSpecProto) .PYFIELD(ShardingSpecProto, tensor_name) .PYFIELD(ShardingSpecProto, device) .PYFIELD(ShardingSpecProto, index_to_device_group_map) .PYFIELD(ShardingSpecProto, sharded_dim) .PYADD_PROTO_SERIALIZATION(ShardingSpecProto); - bind_repeated_field(m, "RepeatedFieldShardingSpecProto"); - py::class_( - m, "NodeDeviceConfigurationProto", "ShardingSpecNodeDeviceConfigurationProtoProto") - .def(py::init<>()) + PYDEFINE_PROTO(m, NodeDeviceConfigurationProto) .PYFIELD(NodeDeviceConfigurationProto, configuration_id) .PYFIELD(NodeDeviceConfigurationProto, sharding_spec) .PYFIELD_OPTIONAL_INT(NodeDeviceConfigurationProto, pipeline_stage) .PYADD_PROTO_SERIALIZATION(NodeDeviceConfigurationProto); - py::class_ cls_tensor_shape_proto( - m, "TensorShapeProto", "TensorShapeProto"); - - py::class_( - cls_tensor_shape_proto, "Dimension", "Dimension, an integer value or a string") - .def(py::init<>()) + PYDEFINE_PROTO_WITH_SUBTYPES(m, TensorShapeProto, cls_tensor_shape_proto); + PYDEFINE_PROTO(cls_tensor_shape_proto, TensorShapeProto::Dimension) .PYFIELD_OPTIONAL_INT(TensorShapeProto::Dimension, dim_value) .PYFIELD(TensorShapeProto::Dimension, dim_param) .PYFIELD(TensorShapeProto::Dimension, denotation) .PYADD_PROTO_SERIALIZATION(TensorShapeProto::Dimension); - bind_repeated_field(m, "RepeatedFieldDimension"); - - cls_tensor_shape_proto.def(py::init<>()) - .PYFIELD(TensorShapeProto, dim) + cls_tensor_shape_proto.PYFIELD(TensorShapeProto, dim) .PYADD_PROTO_SERIALIZATION(TensorShapeProto); - py::enum_(m, "DataType", py::arithmetic()) - .value("UNDEFINED", onnx2::TensorProto::DataType::UNDEFINED) - .value("FLOAT", onnx2::TensorProto::DataType::FLOAT) - .value("UINT8", onnx2::TensorProto::DataType::UINT8) - .value("INT8", onnx2::TensorProto::DataType::INT8) - .value("UINT16", onnx2::TensorProto::DataType::UINT16) - .value("INT16", onnx2::TensorProto::DataType::INT16) - .value("INT32", onnx2::TensorProto::DataType::INT32) - .value("INT64", onnx2::TensorProto::DataType::INT64) - .value("STRING", onnx2::TensorProto::DataType::STRING) - .value("BOOL", onnx2::TensorProto::DataType::BOOL) - .value("FLOAT16", onnx2::TensorProto::DataType::FLOAT16) - .value("DOUBLE", onnx2::TensorProto::DataType::DOUBLE) - .value("UINT32", onnx2::TensorProto::DataType::UINT32) - .value("UINT64", onnx2::TensorProto::DataType::UINT64) - .value("COMPLEX64", onnx2::TensorProto::DataType::COMPLEX64) - .value("COMPLEX128", onnx2::TensorProto::DataType::COMPLEX128) - .value("BFLOAT16", onnx2::TensorProto::DataType::BFLOAT16) - .value("FLOAT8E4M3FN", onnx2::TensorProto::DataType::FLOAT8E4M3FN) - .value("FLOAT8E4M3FNUZ", onnx2::TensorProto::DataType::FLOAT8E4M3FNUZ) - .value("FLOAT8E5M2", onnx2::TensorProto::DataType::FLOAT8E5M2) - .value("FLOAT8E5M2FNUZ", onnx2::TensorProto::DataType::FLOAT8E5M2FNUZ) - .value("UINT4", onnx2::TensorProto::DataType::UINT4) - .value("INT4", onnx2::TensorProto::DataType::INT4) - .value("FLOAT4E2M1", onnx2::TensorProto::DataType::FLOAT4E2M1) - .value("FLOAT8E8M0", onnx2::TensorProto::DataType::FLOAT8E8M0) - .export_values(); - - py::class_(m, "TensorProto", "TensorProto") - .def(py::init<>()) + PYDEFINE_PROTO(m, TensorProto) .SHORTEN_CODE(UNDEFINED) .SHORTEN_CODE(FLOAT) .SHORTEN_CODE(UINT8) @@ -376,51 +352,32 @@ PYBIND11_MODULE(_onnx2py, m) { "raw_data") .PYADD_PROTO_SERIALIZATION(TensorProto); - py::class_(m, "SparseTensorProto", - "SparseTensorProto, sparse tensor") - .def(py::init<>()) + PYDEFINE_PROTO(m, SparseTensorProto) .PYFIELD(SparseTensorProto, values) .PYFIELD(SparseTensorProto, indices) .PYFIELD(SparseTensorProto, dims) .PYADD_PROTO_SERIALIZATION(SparseTensorProto); - py::class_ cls_type_proto(m, "TypeProto", "TypeProto"); - - py::class_(cls_type_proto, "Tensor", - "Tensor, nested class of TypeProto") - .def(py::init<>()) + PYDEFINE_PROTO_WITH_SUBTYPES(m, TypeProto, cls_type_proto); + PYDEFINE_PROTO(cls_type_proto, TypeProto::Tensor) .PYFIELD_OPTIONAL_INT(TypeProto::Tensor, elem_type) .PYFIELD_OPTIONAL_PROTO(TypeProto::Tensor, shape) .PYADD_PROTO_SERIALIZATION(TypeProto::Tensor); - - py::class_( - cls_type_proto, "SparseTensor", "SparseTensor, nested class of TypeProto") - .def(py::init<>()) + PYDEFINE_PROTO(cls_type_proto, TypeProto::SparseTensor) .PYFIELD_OPTIONAL_INT(TypeProto::SparseTensor, elem_type) .PYFIELD_OPTIONAL_PROTO(TypeProto::SparseTensor, shape) .PYADD_PROTO_SERIALIZATION(TypeProto::SparseTensor); - - py::class_(cls_type_proto, "Sequence", - "Sequence, nested class of TypeProto") - .def(py::init<>()) + PYDEFINE_PROTO(cls_type_proto, TypeProto::Sequence) .PYFIELD_OPTIONAL_PROTO(TypeProto::Sequence, elem_type) .PYADD_PROTO_SERIALIZATION(TypeProto::Sequence); - - py::class_(cls_type_proto, "Optional", - "Optional, nested class of TypeProto") - .def(py::init<>()) + PYDEFINE_PROTO(cls_type_proto, TypeProto::Optional) .PYFIELD_OPTIONAL_PROTO(TypeProto::Optional, elem_type) .PYADD_PROTO_SERIALIZATION(TypeProto::Optional); - - py::class_(cls_type_proto, "Map", - "Map, nested class of TypeProto") - .def(py::init<>()) + PYDEFINE_PROTO(cls_type_proto, TypeProto::Map) .PYFIELD(TypeProto::Map, key_type) .PYFIELD_OPTIONAL_PROTO(TypeProto::Map, value_type) .PYADD_PROTO_SERIALIZATION(TypeProto::Map); - - cls_type_proto.def(py::init<>()) - .PYFIELD_OPTIONAL_PROTO(TypeProto, tensor_type) + cls_type_proto.PYFIELD_OPTIONAL_PROTO(TypeProto, tensor_type) .PYFIELD_OPTIONAL_PROTO(TypeProto, sequence_type) .PYFIELD_OPTIONAL_PROTO(TypeProto, map_type) .PYFIELD(TypeProto, denotation) From 8a9a1519100990cde3a85b4ab3168790d8c74795 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 24 Jul 2025 11:03:47 +0200 Subject: [PATCH 2/3] temp --- _unittests/ut_onnx2/test_onnx2.py | 123 ++++++++++++++++++++++- onnx_extended/onnx2/cpu/_onnx2py.cpp | 12 +-- onnx_extended/onnx2/cpu/onnx2.h | 26 +++-- onnx_extended/onnx2/cpu/stream_class.h | 102 ++++++++++++++++++- onnx_extended/onnx2/cpu/stream_class.hpp | 6 +- 5 files changed, 250 insertions(+), 19 deletions(-) diff --git a/_unittests/ut_onnx2/test_onnx2.py b/_unittests/ut_onnx2/test_onnx2.py index ddaf0986..1c652f20 100644 --- a/_unittests/ut_onnx2/test_onnx2.py +++ b/_unittests/ut_onnx2/test_onnx2.py @@ -531,7 +531,7 @@ def test_node_device_configuration_proto(self): b = ps.sharded_dim.add() b.axis = 3 c = b.simple_sharding.add() - c.dim_value = 4 + c.dim_value = 444 c.num_shards = 5 self.assertNotEmpty(p.sharding_spec) self.assertEqual(len(p.sharding_spec), 1) @@ -539,11 +539,132 @@ def test_node_device_configuration_proto(self): s = p.SerializeToString() p2 = x2.NodeDeviceConfigurationProto() p2.ParseFromString(s) + self.assertEqual(len(p2.sharding_spec), 1) + self.assertEqual(p.configuration_id, p2.configuration_id) + self.assertEqual(p.pipeline_stage, p2.pipeline_stage) + self.assertEqual( + p.sharding_spec[0].tensor_name, p2.sharding_spec[0].tensor_name + ) s2 = p2.SerializeToString() p0 = x.NodeDeviceConfigurationProto() p0.ParseFromString(s2) + self.assertEqual( + p.sharding_spec[0].index_to_device_group_map[0].SerializeToString(), + p2.sharding_spec[0] + .index_to_device_group_map[0] + .SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].index_to_device_group_map[0].SerializeToString(), + p0.sharding_spec[0] + .index_to_device_group_map[0] + .SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].index_to_device_group_map[1].SerializeToString(), + p2.sharding_spec[0] + .index_to_device_group_map[1] + .SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].axis, + p0.sharding_spec[0].sharded_dim[0].axis, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].axis, + p2.sharding_spec[0].sharded_dim[0].axis, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].num_shards, + p0.sharding_spec[0].sharded_dim[0].simple_sharding[0].num_shards, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].num_shards, + p2.sharding_spec[0].sharded_dim[0].simple_sharding[0].num_shards, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_value, + p2.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_value, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_param, + p2.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_param, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_value, + p2.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_value, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_value, 444 + ) + self.assertEqual( + p2.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_value, 444 + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].SerializeToString(), + p2.sharding_spec[0].sharded_dim[0].simple_sharding[0].SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_param, + p0.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_param, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_value, + p0.sharding_spec[0].sharded_dim[0].simple_sharding[0].dim_value, + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].SerializeToString(), + p0.sharding_spec[0].sharded_dim[0].simple_sharding[0].SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].simple_sharding[0].SerializeToString(), + p2.sharding_spec[0].sharded_dim[0].simple_sharding[0].SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].SerializeToString(), + p0.sharding_spec[0].sharded_dim[0].SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[0].SerializeToString(), + p2.sharding_spec[0].sharded_dim[0].SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[1].SerializeToString(), + p0.sharding_spec[0].sharded_dim[1].SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].sharded_dim[1].SerializeToString(), + p2.sharding_spec[0].sharded_dim[1].SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].tensor_name, p2.sharding_spec[0].tensor_name + ) + self.assertEqual( + p.sharding_spec[0].tensor_name, p0.sharding_spec[0].tensor_name + ) + self.assertEqual( + list(p.sharding_spec[0].device), list(p2.sharding_spec[0].device) + ) + self.assertEqual( + list(p.sharding_spec[0].device), list(p0.sharding_spec[0].device) + ) + self.assertEqual( + p.sharding_spec[0].index_to_device_group_map[1].SerializeToString(), + p0.sharding_spec[0] + .index_to_device_group_map[1] + .SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].SerializeToString(), + p0.sharding_spec[0].SerializeToString(), + ) + self.assertEqual( + p.sharding_spec[0].SerializeToString(), + p2.sharding_spec[0].SerializeToString(), + ) self.assertEqual(p.SerializeToString(), p0.SerializeToString()) + self.assertEqual(p.SerializeToString(), p2.SerializeToString()) def test_tensor_proto_data_type(self): self.assertEqual(onnx2.TensorProto.UNDEFINED, onnx.TensorProto.UNDEFINED) diff --git a/onnx_extended/onnx2/cpu/_onnx2py.cpp b/onnx_extended/onnx2/cpu/_onnx2py.cpp index 5429f47d..1379fd8d 100644 --- a/onnx_extended/onnx2/cpu/_onnx2py.cpp +++ b/onnx_extended/onnx2/cpu/_onnx2py.cpp @@ -37,15 +37,15 @@ namespace py = pybind11; def_property( \ #name, \ [](onnx2::cls &self) -> py::object { \ - if (!self.name##_.has_value()) \ + if (!self.has_##name()) \ return py::none(); \ - return py::cast(*self.name##_, py::return_value_policy::reference); \ + return py::cast(self.name(), py::return_value_policy::reference); \ }, \ [](onnx2::cls &self, py::object obj) { \ if (obj.is_none()) { \ - self.name##_.reset(); \ + self.reset_##name(); \ } else if (py::isinstance(obj)) { \ - self.name##_ = obj.cast(); \ + self.set_##name(obj.cast()); \ } else { \ EXT_THROW("unexpected value type, unable to set '" #name "' for class '" #cls "'"); \ } \ @@ -228,7 +228,7 @@ PYBIND11_MODULE(_onnx2py, m) { .PYADD_PROTO_SERIALIZATION(DeviceConfigurationProto); PYDEFINE_PROTO(m, SimpleShardedDimProto) - .PYFIELD_OPTIONAL_INT(SimpleShardedDimProto, dim_value) + .PYFIELD(SimpleShardedDimProto, dim_value) .PYFIELD(SimpleShardedDimProto, dim_param) .PYFIELD(SimpleShardedDimProto, num_shards) .PYADD_PROTO_SERIALIZATION(SimpleShardedDimProto); @@ -256,7 +256,7 @@ PYBIND11_MODULE(_onnx2py, m) { PYDEFINE_PROTO_WITH_SUBTYPES(m, TensorShapeProto, cls_tensor_shape_proto); PYDEFINE_PROTO(cls_tensor_shape_proto, TensorShapeProto::Dimension) - .PYFIELD_OPTIONAL_INT(TensorShapeProto::Dimension, dim_value) + .PYFIELD(TensorShapeProto::Dimension, dim_value) .PYFIELD(TensorShapeProto::Dimension, dim_param) .PYFIELD(TensorShapeProto::Dimension, denotation) .PYADD_PROTO_SERIALIZATION(TensorShapeProto::Dimension); diff --git a/onnx_extended/onnx2/cpu/onnx2.h b/onnx_extended/onnx2/cpu/onnx2.h index 838b3d43..ab28db0d 100644 --- a/onnx_extended/onnx2/cpu/onnx2.h +++ b/onnx_extended/onnx2/cpu/onnx2.h @@ -30,10 +30,17 @@ FIELD(int32_t, num_devices, 2) FIELD_REPEATED(std::string, device, 3) END_PROTO() -BEGIN_PROTO(SimpleShardedDimProto) -FIELD_OPTIONAL(int64_t, dim_value, 1) -FIELD(std::string, dim_param, 2) +BEGIN_PROTO_NO_INIT(SimpleShardedDimProto) +inline SimpleShardedDimProto() : dim_filled_order(0) {} +FIELD_ONEOF_BEGIN(dim) +FIELD_ONEOF_0(int64_t, dim_value) +FIELD_ONEOF_0_STRING(dim_param) +FIELD_ONEOF_MIDDLE(dim) +FIELD_ONEOF_1(dim, int64_t, dim_value, 1) +FIELD_ONEOF_1_STRING(dim, dim_param, 2) +FIELD_ONEOF_END(dim, 2) FIELD(int64_t, num_shards, 3) +inline ~SimpleShardedDimProto() { clear_dim(); } END_PROTO() BEGIN_PROTO_NOINIT(ShardedDimProto) @@ -61,10 +68,17 @@ FIELD(int64_t, version, 2) END_PROTO() BEGIN_PROTO_NOINIT(TensorShapeProto) -BEGIN_PROTO(Dimension) -FIELD_OPTIONAL(int64_t, dim_value, 1) -FIELD(std::string, dim_param, 2) +BEGIN_PROTO_NOINIT(Dimension) +inline Dimension() : dim_filled_order(0) {} +FIELD_ONEOF_BEGIN(dim) +FIELD_ONEOF_0(int64_t, dim_value) +FIELD_ONEOF_0_STRING(dim_param) +FIELD_ONEOF_MIDDLE(dim) +FIELD_ONEOF_1(dim, int64_t, dim_value, 1) +FIELD_ONEOF_1_STRING(dim, dim_param, 2) +FIELD_ONEOF_END(dim, 2) FIELD(std::string, denotation, 3) +inline ~Dimension() { clear_dim(); } END_PROTO() inline TensorShapeProto() {} FIELD_REPEATED(Dimension, dim, 1) diff --git a/onnx_extended/onnx2/cpu/stream_class.h b/onnx_extended/onnx2/cpu/stream_class.h index 1a8fda11..4f9209bb 100644 --- a/onnx_extended/onnx2/cpu/stream_class.h +++ b/onnx_extended/onnx2/cpu/stream_class.h @@ -40,28 +40,107 @@ #if defined(FIELD) #pragma error("macro FIELD is already defined.") #endif + +#define FDEC_1 0 +#define FDEC_2 1 +#define FDEC_3 2 +#define FDEC(x) FDEC_##x +#define FCONCAT(a, b) a##b +#define FEXPAND_CONCAT(a, b) FCONCAT(a, b) + #define FIELD(type, name, order) \ public: \ inline type &name() { return name##_; } \ + inline const type &name() const { return name##_; } \ + inline void set_##name(const type &v) { name##_ = v; } \ inline bool has_##name() const { return _has_field_(name##_); } \ inline int order_##name() const { return order; } \ - type name##_; + type name##_; \ + using name##_t = type; + +#define FIELD_ONEOF_BEGIN(varname) \ +public: \ + union { + +#define FIELD_ONEOF_0(type, name) type name##_; +#define FIELD_ONEOF_0_STRING(name) std::string *name##_; + +#define FIELD_ONEOF_MIDDLE(varname) \ + } \ + ; \ + int32_t varname##_filled_order; \ + inline void clear_##varname##_0() {} + +#define FIELD_ONEOF_1(varname, type, name, order) \ + using name##_t = type; \ + inline type &name() { return name##_; } \ + inline const type &name() const { return name##_; } \ + inline void set_##name(const type &v) { \ + name##_ = v; \ + varname##_filled_order = order; \ + } \ + inline bool has_##name() const { return varname##_filled_order == order; } \ + inline void clear_##name() { \ + if (has_##name()) { \ + clear_one_of(name##_); \ + varname##_filled_order = 0; \ + } \ + } \ + inline void clear_##varname##_##order() { \ + clear_##name(); \ + FEXPAND_CONCAT(clear_##varname##_, FDEC(order))(); \ + } \ + inline int order_##name() const { return order; } + +#define FIELD_ONEOF_1_STRING(varname, name, order) \ + using name##_t = std::string; \ + inline type &name() { return has_##name() ? *name##_ : std::string(); } \ + inline const type &name() const { return has_##name() ? *name##_ : std::string(); } \ + inline void set_##name(const type &v) { \ + if (has_##name()) \ + *name##_ = v; \ + else \ + name##_ = new std::string(v); \ + varname##_filled_order = order; \ + } \ + inline bool has_##name() const { return varname##_filled_order == order; } \ + inline void clear_##name() { \ + if (has_##name()) { \ + if (&&name##_ != nullptr) { \ + delete name##_; \ + name##_ = nullptr; \ + } \ + varname##_filled_order = 0; \ + } \ + } \ + inline void clear_##varname##_##order() { \ + clear_##name(); \ + FEXPAND_CONCAT(clear_##varname##_, FDEC(order))(); \ + } \ + inline int order_##name() const { return order; } + +#define FIELD_ONEOF_END(varname, n) \ + inline void clear_##varname() { clear_##varname##_##n(); }; #define FIELD_REPEATED(type, name, order) \ public: \ inline utils::RepeatedField &name() { return name##_; } \ + inline void add_##name(type &&v) { name##_.emplace_back(v); } \ inline bool has_##name() const { return _has_field_(name##_) && !name##_.empty(); } \ inline int order_##name() const { return order; } \ inline bool packed_##name() const { return false; } \ - utils::RepeatedField name##_; + utils::RepeatedField name##_; \ + using name##_t = type; #define FIELD_REPEATED_PACKED(type, name, order) \ public: \ inline utils::RepeatedField &name() { return name##_; } \ + inline void add_##name(type &&v) { name##_.emplace_back(v); } \ inline bool has_##name() const { return _has_field_(name##_) && !name##_.empty(); } \ inline int order_##name() const { return order; } \ inline bool packed_##name() const { return true; } \ - utils::RepeatedField name##_; + utils::RepeatedField name##_; \ + using name##_t = type; #define FIELD_OPTIONAL(type, name, order) \ public: \ @@ -69,6 +148,20 @@ public: EXT_ENFORCE(name##_.has_value(), "Optional field '", #name, "' has no value."); \ return *name##_; \ } \ + inline const type &name() const { \ + EXT_ENFORCE(name##_.has_value(), "Optional field '", #name, "' has no value."); \ + return *name##_; \ + } \ + inline utils::OptionalField &name##_optional() { \ + EXT_ENFORCE(name##_.has_value(), "Optional field '", #name, "' has no value."); \ + return name##_; \ + } \ + inline const utils::OptionalField &name##_optional() const { \ + EXT_ENFORCE(name##_.has_value(), "Optional field '", #name, "' has no value."); \ + return name##_; \ + } \ + inline void set_##name(const type &v) { name##_ = v; } \ + inline void reset_##name() { name##_.reset(); } \ inline bool has_##name() const { return _has_field_(name##_); } \ inline int order_##name() const { return order; } \ utils::OptionalField name##_; \ @@ -78,6 +171,9 @@ namespace onnx2 { using utils::offset_t; +template inline void clear_one_of(T &) {} +template <> inline void clear_one_of(std::string &s) { s.clear(); } + template inline bool _has_field_(const T &) { return true; } template <> inline bool _has_field_(const std::string &field) { return !field.empty(); } template <> inline bool _has_field_(const utils::OptionalField &field) { diff --git a/onnx_extended/onnx2/cpu/stream_class.hpp b/onnx_extended/onnx2/cpu/stream_class.hpp index 7a4af3cb..2fd6190d 100644 --- a/onnx_extended/onnx2/cpu/stream_class.hpp +++ b/onnx_extended/onnx2/cpu/stream_class.hpp @@ -24,12 +24,12 @@ #define WRITE_FIELD(stream, name) \ if (has_##name()) { \ - write_field(stream, order_##name(), name##_); \ + write_field(stream, order_##name(), name()); \ } #define WRITE_ENUM_FIELD(stream, name) \ if (has_##name()) { \ - write_enum_field(stream, order_##name(), name##_); \ + write_enum_field(stream, order_##name(), name()); \ } #define WRITE_REPEATED_FIELD(stream, name) \ @@ -39,7 +39,7 @@ #define WRITE_OPTIONAL_PROTO_FIELD(stream, name) \ if (has_##name()) { \ - write_optional_proto_field(stream, order_##name(), name##_); \ + write_optional_proto_field(stream, order_##name(), name##_optional()); \ } ///////////// From 2535490bc78d5800362219701e9e61db9a53de43 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 24 Jul 2025 11:59:25 +0200 Subject: [PATCH 3/3] changes --- _unittests/ut_onnx2/test_onnx2.py | 4 +- onnx_extended/onnx2/cpu/onnx2.cpp | 8 ++-- onnx_extended/onnx2/cpu/onnx2.h | 6 +-- onnx_extended/onnx2/cpu/stream_class.h | 48 ++++++++++++++++++++---- onnx_extended/onnx2/cpu/stream_class.hpp | 31 +++++++++------ 5 files changed, 69 insertions(+), 28 deletions(-) diff --git a/_unittests/ut_onnx2/test_onnx2.py b/_unittests/ut_onnx2/test_onnx2.py index 1c652f20..aba66e77 100644 --- a/_unittests/ut_onnx2/test_onnx2.py +++ b/_unittests/ut_onnx2/test_onnx2.py @@ -432,7 +432,7 @@ def test_device_configuration_proto(self): def test_simple_shared_dim_proto(self): for x, x2 in [(onnx, onnx2), (onnx2, onnx)]: - with self.subTest(start=x.__name__): + with self.subTest(start=x.__name__, case="dim_value"): p = x.SimpleShardedDimProto() p.dim_value = 3 # p.dim_param = "rt" @@ -450,7 +450,7 @@ def test_simple_shared_dim_proto(self): self.assertEqual(p.SerializeToString(), p0.SerializeToString()) for x, x2 in [(onnx, onnx2), (onnx2, onnx)]: - with self.subTest(start=x.__name__): + with self.subTest(start=x.__name__, case="dim_param"): p = x.SimpleShardedDimProto() # p.dim_value = 3 p.dim_param = "rt" diff --git a/onnx_extended/onnx2/cpu/onnx2.cpp b/onnx_extended/onnx2/cpu/onnx2.cpp index 7690f1a6..1ddfb472 100644 --- a/onnx_extended/onnx2/cpu/onnx2.cpp +++ b/onnx_extended/onnx2/cpu/onnx2.cpp @@ -76,8 +76,8 @@ void SimpleShardedDimProto::SerializeToStream(utils::BinaryWriteStream &stream) void SimpleShardedDimProto::ParseFromStream(utils::BinaryStream &stream) { READ_BEGIN(stream, SimpleShardedDimProto) - READ_FIELD(stream, dim_value) - READ_FIELD(stream, dim_param) + READ_ONEOF_FIELD(stream, dim, dim_value) + READ_ONEOF_FIELD(stream, dim, dim_param) READ_FIELD(stream, num_shards) READ_END(stream, SimpleShardedDimProto) } @@ -160,8 +160,8 @@ void TensorShapeProto::Dimension::SerializeToStream(utils::BinaryWriteStream &st void TensorShapeProto::Dimension::ParseFromStream(utils::BinaryStream &stream) { READ_BEGIN(stream, TensorShapeProto::Dimension) - READ_FIELD(stream, dim_value) - READ_FIELD(stream, dim_param) + READ_ONEOF_FIELD(stream, dim, dim_value) + READ_ONEOF_FIELD(stream, dim, dim_param) READ_FIELD(stream, denotation) READ_END(stream, TensorShapeProto::Dimension) } diff --git a/onnx_extended/onnx2/cpu/onnx2.h b/onnx_extended/onnx2/cpu/onnx2.h index ab28db0d..afcafb67 100644 --- a/onnx_extended/onnx2/cpu/onnx2.h +++ b/onnx_extended/onnx2/cpu/onnx2.h @@ -30,8 +30,7 @@ FIELD(int32_t, num_devices, 2) FIELD_REPEATED(std::string, device, 3) END_PROTO() -BEGIN_PROTO_NO_INIT(SimpleShardedDimProto) -inline SimpleShardedDimProto() : dim_filled_order(0) {} +BEGIN_PROTO_NOINIT(SimpleShardedDimProto) FIELD_ONEOF_BEGIN(dim) FIELD_ONEOF_0(int64_t, dim_value) FIELD_ONEOF_0_STRING(dim_param) @@ -40,6 +39,7 @@ FIELD_ONEOF_1(dim, int64_t, dim_value, 1) FIELD_ONEOF_1_STRING(dim, dim_param, 2) FIELD_ONEOF_END(dim, 2) FIELD(int64_t, num_shards, 3) +inline SimpleShardedDimProto() : dim_param_(nullptr), dim_filled_order(0) {} inline ~SimpleShardedDimProto() { clear_dim(); } END_PROTO() @@ -69,7 +69,6 @@ END_PROTO() BEGIN_PROTO_NOINIT(TensorShapeProto) BEGIN_PROTO_NOINIT(Dimension) -inline Dimension() : dim_filled_order(0) {} FIELD_ONEOF_BEGIN(dim) FIELD_ONEOF_0(int64_t, dim_value) FIELD_ONEOF_0_STRING(dim_param) @@ -78,6 +77,7 @@ FIELD_ONEOF_1(dim, int64_t, dim_value, 1) FIELD_ONEOF_1_STRING(dim, dim_param, 2) FIELD_ONEOF_END(dim, 2) FIELD(std::string, denotation, 3) +inline Dimension() : dim_param_(nullptr), dim_filled_order(0) {} inline ~Dimension() { clear_dim(); } END_PROTO() inline TensorShapeProto() {} diff --git a/onnx_extended/onnx2/cpu/stream_class.h b/onnx_extended/onnx2/cpu/stream_class.h index 4f9209bb..b745b13c 100644 --- a/onnx_extended/onnx2/cpu/stream_class.h +++ b/onnx_extended/onnx2/cpu/stream_class.h @@ -2,6 +2,16 @@ #include "stream.h" +#define DEBUG_READ + +#if defined(DEBUG_READ) +#define DEBUG_PRINT(s) printf("%s\n", s); +#define DEBUG_PRINT2(s1, s2) printf("%s%s\n", s1, s2); +#else +#define DEBUG_PRINT(s) +#define DEBUG_PRINT2(s1, s2) +#endif + #define FIELD_VARINT 0 // #define FIELD_FIXED64 1 #define FIELD_FIXED_SIZE 2 @@ -69,7 +79,7 @@ public: } \ ; \ int32_t varname##_filled_order; \ - inline void clear_##varname##_0() {} + inline void clear_##varname##_0() { DEBUG_PRINT("clear " #varname "_0") } #define FIELD_ONEOF_1(varname, type, name, order) \ using name##_t = type; \ @@ -81,6 +91,7 @@ public: } \ inline bool has_##name() const { return varname##_filled_order == order; } \ inline void clear_##name() { \ + DEBUG_PRINT("clear " #name) \ if (has_##name()) { \ clear_one_of(name##_); \ varname##_filled_order = 0; \ @@ -94,19 +105,39 @@ public: #define FIELD_ONEOF_1_STRING(varname, name, order) \ using name##_t = std::string; \ - inline type &name() { return has_##name() ? *name##_ : std::string(); } \ - inline const type &name() const { return has_##name() ? *name##_ : std::string(); } \ - inline void set_##name(const type &v) { \ + inline std::string &name() { \ + if (has_##name() && name##_ != nullptr) { \ + DEBUG_PRINT("access " #name "_") \ + return *name##_; \ + } else { \ + static std::string _; \ + return _; \ + } \ + } \ + inline const std::string &name() const { \ + if (has_##name() && name##_ != nullptr) { \ + DEBUG_PRINT("access " #name "_") \ + return *name##_; \ + } else { \ + static std::string _; \ + return _; \ + } \ + } \ + inline void set_##name(const std::string &v) { \ if (has_##name()) \ *name##_ = v; \ - else \ + else { \ + DEBUG_PRINT("new " #name "_") \ name##_ = new std::string(v); \ + } \ varname##_filled_order = order; \ } \ inline bool has_##name() const { return varname##_filled_order == order; } \ inline void clear_##name() { \ + DEBUG_PRINT("clear " #name) \ if (has_##name()) { \ - if (&&name##_ != nullptr) { \ + if (name##_ != nullptr) { \ + DEBUG_PRINT("delete " #name "_") \ delete name##_; \ name##_ = nullptr; \ } \ @@ -120,7 +151,10 @@ public: inline int order_##name() const { return order; } #define FIELD_ONEOF_END(varname, n) \ - inline void clear_##varname() { clear_##varname##_##n(); }; + inline void clear_##varname() { \ + DEBUG_PRINT("clear " #varname "_" #n); \ + clear_##varname##_##n(); \ + }; #define FIELD_REPEATED(type, name, order) \ public: \ diff --git a/onnx_extended/onnx2/cpu/stream_class.hpp b/onnx_extended/onnx2/cpu/stream_class.hpp index 2fd6190d..8597ff95 100644 --- a/onnx_extended/onnx2/cpu/stream_class.hpp +++ b/onnx_extended/onnx2/cpu/stream_class.hpp @@ -8,16 +8,6 @@ #include #include -// #define DEBUG_READ - -#if defined(DEBUG_READ) -#define DEBUG_PRINT(s) printf("%s\n", s); -#define DEBUG_PRINT2(s1, s2) printf("%s%s\n", s1, s2); -#else -#define DEBUG_PRINT(s) -#define DEBUG_PRINT2(s1, s2) -#endif - ////////////// // macro write ////////////// @@ -69,6 +59,14 @@ DEBUG_PRINT(" - field " #name) \ } +#define READ_ONEOF_FIELD(stream, varname, name) \ + else if (static_cast(field_number.field_number) == order_##name()) { \ + DEBUG_PRINT(" + field " #name) \ + read_field(stream, field_number.wire_type, name##_, #name); \ + varname##_filled_order = order_##name(); \ + DEBUG_PRINT(" - field " #name) \ + } + #define READ_OPTIONAL_PROTO_FIELD(stream, name) \ else if (static_cast(field_number.field_number) == order_##name()) { \ DEBUG_PRINT(" + optional field " #name) \ @@ -286,13 +284,22 @@ void read_optional_proto_field(utils::BinaryStream &stream, int wire_type, } template <> -void read_field(utils::BinaryStream &stream, int wire_type, std::string &field, - const char *name) { +void read_field(utils::BinaryStream &stream, int wire_type, std::string &field, + const char *name) { EXT_ENFORCE(wire_type == FIELD_FIXED_SIZE, "unexpected wire_type=", wire_type, " for field '", name, "'"); field = stream.next_string(); } +template <> +void read_field(utils::BinaryStream &stream, int wire_type, std::string *&field, + const char *name) { + EXT_ENFORCE(wire_type == FIELD_FIXED_SIZE, "unexpected wire_type=", wire_type, " for field '", + name, "'"); + DEBUG_PRINT("new std::string") + field = new std::string(stream.next_string()); +} + template <> void read_field(utils::BinaryStream &stream, int wire_type, utils::OptionalField &field, const char *name) {