From 2888208adfccbed31be0aa21403f63fcaad9c32f Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Sun, 8 Sep 2019 09:59:29 +0200 Subject: [PATCH 1/7] feat: tool to dump the onnx structure for debugging --- internal/tools/dump/main.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 internal/tools/dump/main.go diff --git a/internal/tools/dump/main.go b/internal/tools/dump/main.go new file mode 100644 index 00000000..b91496b5 --- /dev/null +++ b/internal/tools/dump/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "io/ioutil" + "log" + "os" + + "github.com/davecgh/go-spew/spew" + "github.com/owulveryck/onnx-go/internal/pb-onnx" +) + +func main() { + onnxFile := os.Args[1] + b, err := ioutil.ReadFile(onnxFile) + if err != nil { + log.Fatal(err) + } + var m pb.ModelProto + err = m.XXX_Unmarshal(b) + if err != nil { + log.Fatal(err) + } + + scs := spew.ConfigState{ + Indent: "\t", + DisablePointerAddresses: true, + } + scs.Dump(m) +} From 179f007aff7119333bcba67a3bd548bb6df2094c Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Sun, 8 Sep 2019 10:47:32 +0200 Subject: [PATCH 2/7] feat: use litter for a better pretty-printing --- internal/tools/dump/main.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/internal/tools/dump/main.go b/internal/tools/dump/main.go index b91496b5..8889460d 100644 --- a/internal/tools/dump/main.go +++ b/internal/tools/dump/main.go @@ -5,8 +5,8 @@ import ( "log" "os" - "github.com/davecgh/go-spew/spew" "github.com/owulveryck/onnx-go/internal/pb-onnx" + "github.com/sanity-io/litter" ) func main() { @@ -20,10 +20,6 @@ func main() { if err != nil { log.Fatal(err) } + litter.Dump(m) - scs := spew.ConfigState{ - Indent: "\t", - DisablePointerAddresses: true, - } - scs.Dump(m) } From 4a571eeb2e6d8a60e8e79633ad0baacd44847ce6 Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Sun, 8 Sep 2019 19:52:29 +0200 Subject: [PATCH 3/7] feat: placeholder for testing issue 120 --- self_node_test.go | 95 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 self_node_test.go diff --git a/self_node_test.go b/self_node_test.go new file mode 100644 index 00000000..37907385 --- /dev/null +++ b/self_node_test.go @@ -0,0 +1,95 @@ +package onnx + +import ( + "testing" + + pb "github.com/owulveryck/onnx-go/internal/pb-onnx" +) + +func TestDecodeProto_self(t *testing.T) { + _ = pb.ModelProto{ + IrVersion: 5, + OpsetImport: []*pb.OperatorSetIdProto{ + &pb.OperatorSetIdProto{ + Domain: "", + Version: 7, + }, + }, + ProducerName: "tf2onnx", + ProducerVersion: "1.5.3", + Domain: "", + ModelVersion: 0, + DocString: "", + Graph: &pb.GraphProto{ + Node: []*pb.NodeProto{ + &pb.NodeProto{ + Input: []string{ + "x:0", + "x:0", + }, + Output: []string{ + "mul:0", + }, + Name: "mul", + OpType: "Mul", + Domain: "", + Attribute: nil, + DocString: "", + }, + }, + Name: "tf2onnx", + Initializer: nil, + DocString: "converted from ./model_nowind_test/export/", + Input: []*pb.ValueInfoProto{ + &pb.ValueInfoProto{ + Name: "x:0", + Type: &pb.TypeProto{ + Value: &pb.TypeProto_TensorType{ + TensorType: &pb.TypeProto_Tensor{ + ElemType: 1, + Shape: &pb.TensorShapeProto{ + Dim: []*pb.TensorShapeProto_Dimension{ + &pb.TensorShapeProto_Dimension{ + Value: &pb.TensorShapeProto_Dimension_DimValue{ + DimValue: 1, + }, + Denotation: "", + }, + }, + }, + }, + }, + Denotation: "", + }, + DocString: "", + }, + }, + Output: []*pb.ValueInfoProto{ + &pb.ValueInfoProto{ + Name: "mul:0", + Type: &pb.TypeProto{ + Value: &pb.TypeProto_TensorType{ + TensorType: &pb.TypeProto_Tensor{ + ElemType: 1, + Shape: &pb.TensorShapeProto{ + Dim: []*pb.TensorShapeProto_Dimension{ + &pb.TensorShapeProto_Dimension{ + Value: &pb.TensorShapeProto_Dimension_DimValue{ + DimValue: 1, + }, + Denotation: "", + }, + }, + }, + }, + }, + Denotation: "", + }, + DocString: "", + }, + }, + ValueInfo: nil, + }, + MetadataProps: nil, + } +} From fd651b09f9f52618112751da4eb3ba5e3a3efd24 Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Mon, 9 Sep 2019 16:42:27 +0200 Subject: [PATCH 4/7] feat: test the self edge --- self_node_test.go | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/self_node_test.go b/self_node_test.go index 37907385..c7b0504f 100644 --- a/self_node_test.go +++ b/self_node_test.go @@ -4,10 +4,11 @@ import ( "testing" pb "github.com/owulveryck/onnx-go/internal/pb-onnx" + "gonum.org/v1/gonum/graph" ) func TestDecodeProto_self(t *testing.T) { - _ = pb.ModelProto{ + input := &pb.ModelProto{ IrVersion: 5, OpsetImport: []*pb.OperatorSetIdProto{ &pb.OperatorSetIdProto{ @@ -92,4 +93,28 @@ func TestDecodeProto_self(t *testing.T) { }, MetadataProps: nil, } + backend := newTestBackend() + + m := NewModel(backend) + err := m.decodeProto(input) + if err != nil { + t.Fatal(err) + } + edges := backend.WeightedEdges() + if edges.Len() != 2 { + t.Fatal("expected 2 weighted edges") + } + ee := make([]graph.WeightedEdge, 2) + for i := 0; edges.Next(); i++ { + ee[i] = edges.WeightedEdge() + } + for i := 0; i < len(ee); i++ { + if ee[i].From() == ee[i].To() && ee[i].Weight() == self { + ee = ee[:len(ee)-1] + } + if ee[i].From() != ee[i].To() && ee[i].Weight() == 1 { + ee = ee[:len(ee)-1] + } + } + } From 62e8f1c5fc30e991a7d2b6a20327adaaf36ceae4 Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Thu, 12 Sep 2019 12:44:20 +0200 Subject: [PATCH 5/7] feat: add the WeightedEdges method to the test backend --- dummy_backend_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dummy_backend_test.go b/dummy_backend_test.go index a50429a3..6d5c6880 100644 --- a/dummy_backend_test.go +++ b/dummy_backend_test.go @@ -2,7 +2,6 @@ package onnx import ( "fmt" - "math" "gonum.org/v1/gonum/graph" "gonum.org/v1/gonum/graph/encoding" @@ -91,7 +90,7 @@ func (n *nodeTest) ApplyTensor(t tensor.Tensor) error { // NewSimpleGraph ... func newTestBackend() *testBackend { return &testBackend{ - g: simple.NewWeightedDirectedGraph(math.MaxFloat64, -1), + g: simple.NewWeightedDirectedGraph(self, -1), } } @@ -107,6 +106,7 @@ func (g *testBackend) AddNode(n graph.Node) { g.g.AddNode(n) } + func (g *testBackend) NewNode() graph.Node { n := g.g.NewNode() return &nodeTest{ @@ -151,3 +151,8 @@ func (g *testBackend) ApplyOperation(_ Operation, _ ...graph.Node) error { func (g *testBackend) WeightedEdge(uid, vid int64) graph.WeightedEdge { return g.g.WeightedEdge(uid, vid) } + +// WeightedEdges returns all the weighted edges in the graph. +func (g *testBackend) WeightedEdges() graph.WeightedEdges { + return g.g.WeightedEdges() +} From 74e721879d331fe6e120b04c0bebfa5e1fe3c376 Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Thu, 12 Sep 2019 13:25:21 +0200 Subject: [PATCH 6/7] feat(wip):: introducting an exported value for SelfEdge --- backend.go | 5 +++++ dummy_backend_test.go | 2 +- self_node_test.go | 2 +- utils_test.go | 20 ++++++++++++++++++++ weighed_graph_test.go | 6 ++---- 5 files changed, 29 insertions(+), 6 deletions(-) create mode 100644 utils_test.go diff --git a/backend.go b/backend.go index 898fff40..e8e51193 100644 --- a/backend.go +++ b/backend.go @@ -1,9 +1,14 @@ package onnx import ( + "math" + "gonum.org/v1/gonum/graph" ) +// SelfEdge is the weight of a self edge in the graph +const SelfEdge = math.MaxFloat64 + // Backend represent any backend able to receive a computation graph type Backend interface { OperationCarrier diff --git a/dummy_backend_test.go b/dummy_backend_test.go index 6d5c6880..6ce30666 100644 --- a/dummy_backend_test.go +++ b/dummy_backend_test.go @@ -90,7 +90,7 @@ func (n *nodeTest) ApplyTensor(t tensor.Tensor) error { // NewSimpleGraph ... func newTestBackend() *testBackend { return &testBackend{ - g: simple.NewWeightedDirectedGraph(self, -1), + g: simple.NewWeightedDirectedGraph(SelfEdge, -1), } } diff --git a/self_node_test.go b/self_node_test.go index c7b0504f..e4a82ba0 100644 --- a/self_node_test.go +++ b/self_node_test.go @@ -109,7 +109,7 @@ func TestDecodeProto_self(t *testing.T) { ee[i] = edges.WeightedEdge() } for i := 0; i < len(ee); i++ { - if ee[i].From() == ee[i].To() && ee[i].Weight() == self { + if ee[i].From() == ee[i].To() && ee[i].Weight() == SelfEdge { ee = ee[:len(ee)-1] } if ee[i].From() != ee[i].To() && ee[i].Weight() == 1 { diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 00000000..642b2c86 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,20 @@ +package onnx + +import "testing" + +func TestContains(t *testing.T) { + table := []string{"a", "b", "c"} + ok := contains("a", table) + if !ok { + t.Fail() + } + ok = contains("z", table) + if ok { + t.Fail() + } + table = []string{"a", "a", "b", "c"} + ok = contains("a", table) + if !ok { + t.Fail() + } +} diff --git a/weighed_graph_test.go b/weighed_graph_test.go index 7fcd28ab..0d7252dd 100644 --- a/weighed_graph_test.go +++ b/weighed_graph_test.go @@ -1,15 +1,13 @@ package onnx import ( - "math" - "gonum.org/v1/gonum/graph" "gonum.org/v1/gonum/graph/iterator" "gonum.org/v1/gonum/graph/simple" ) const ( - self, absent = math.MaxFloat64, float64(-1) + absent = float64(-1) ) type edge struct { @@ -149,7 +147,7 @@ func (g *testExpectedGraph) To(id int64) graph.Nodes { // exists between x and y or if x and y have the same ID, false otherwise. func (g *testExpectedGraph) Weight(xid, yid int64) (w float64, ok bool) { if xid == yid { - return self, true + return SelfEdge, true } if to, ok := g.from[xid]; ok { if e, ok := to[yid]; ok { From 653185afb883a9c0ba360ddfb46dff2fe6daf877 Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Thu, 12 Sep 2019 13:26:14 +0200 Subject: [PATCH 7/7] feat(wip): attempt to add a self-edge; tests are failing --- decoder.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/decoder.go b/decoder.go index aeeaac88..9094ba8d 100644 --- a/decoder.go +++ b/decoder.go @@ -182,7 +182,8 @@ func (m *Model) decodeProto(model *pb.ModelProto) error { m.dbByName[output] = no } // input should be ordered for non-commutatives operations - for i, input := range node.Input { + for i := 0; i < len(node.Input); i++ { + input := node.Input[i] var ni graph.Node var ok bool if ni, ok = m.dbByName[input]; !ok { @@ -194,6 +195,10 @@ func (m *Model) decodeProto(model *pb.ModelProto) error { m.dbByName[input] = ni } e := dst.NewWeightedEdge(no, ni, float64(i)) + if i < len(node.Input)-1 && contains(input, node.Input[i+1:]) { + node.Input = append(node.Input[:i], node.Input[i+1:]...) + e = dst.NewWeightedEdge(no, no, SelfEdge) + } dst.SetWeightedEdge(e) } outputNodes[i] = no