diff --git a/backend.go b/backend.go index 898fff40..e8e51193 100644 --- a/backend.go +++ b/backend.go @@ -1,9 +1,14 @@ package onnx import ( + "math" + "gonum.org/v1/gonum/graph" ) +// SelfEdge is the weight of a self edge in the graph +const SelfEdge = math.MaxFloat64 + // Backend represent any backend able to receive a computation graph type Backend interface { OperationCarrier diff --git a/decoder.go b/decoder.go index aeeaac88..9094ba8d 100644 --- a/decoder.go +++ b/decoder.go @@ -182,7 +182,8 @@ func (m *Model) decodeProto(model *pb.ModelProto) error { m.dbByName[output] = no } // input should be ordered for non-commutatives operations - for i, input := range node.Input { + for i := 0; i < len(node.Input); i++ { + input := node.Input[i] var ni graph.Node var ok bool if ni, ok = m.dbByName[input]; !ok { @@ -194,6 +195,10 @@ func (m *Model) decodeProto(model *pb.ModelProto) error { m.dbByName[input] = ni } e := dst.NewWeightedEdge(no, ni, float64(i)) + if i < len(node.Input)-1 && contains(input, node.Input[i+1:]) { + node.Input = append(node.Input[:i], node.Input[i+1:]...) + e = dst.NewWeightedEdge(no, no, SelfEdge) + } dst.SetWeightedEdge(e) } outputNodes[i] = no diff --git a/dummy_backend_test.go b/dummy_backend_test.go index a50429a3..6ce30666 100644 --- a/dummy_backend_test.go +++ b/dummy_backend_test.go @@ -2,7 +2,6 @@ package onnx import ( "fmt" - "math" "gonum.org/v1/gonum/graph" "gonum.org/v1/gonum/graph/encoding" @@ -91,7 +90,7 @@ func (n *nodeTest) ApplyTensor(t tensor.Tensor) error { // NewSimpleGraph ... func newTestBackend() *testBackend { return &testBackend{ - g: simple.NewWeightedDirectedGraph(math.MaxFloat64, -1), + g: simple.NewWeightedDirectedGraph(SelfEdge, -1), } } @@ -107,6 +106,7 @@ func (g *testBackend) AddNode(n graph.Node) { g.g.AddNode(n) } + func (g *testBackend) NewNode() graph.Node { n := g.g.NewNode() return &nodeTest{ @@ -151,3 +151,8 @@ func (g *testBackend) ApplyOperation(_ Operation, _ ...graph.Node) error { func (g *testBackend) WeightedEdge(uid, vid int64) graph.WeightedEdge { return g.g.WeightedEdge(uid, vid) } + +// WeightedEdges returns all the weighted edges in the graph. +func (g *testBackend) WeightedEdges() graph.WeightedEdges { + return g.g.WeightedEdges() +} diff --git a/internal/tools/dump/main.go b/internal/tools/dump/main.go new file mode 100644 index 00000000..8889460d --- /dev/null +++ b/internal/tools/dump/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "io/ioutil" + "log" + "os" + + "github.com/owulveryck/onnx-go/internal/pb-onnx" + "github.com/sanity-io/litter" +) + +func main() { + onnxFile := os.Args[1] + b, err := ioutil.ReadFile(onnxFile) + if err != nil { + log.Fatal(err) + } + var m pb.ModelProto + err = m.XXX_Unmarshal(b) + if err != nil { + log.Fatal(err) + } + litter.Dump(m) + +} diff --git a/self_node_test.go b/self_node_test.go new file mode 100644 index 00000000..e4a82ba0 --- /dev/null +++ b/self_node_test.go @@ -0,0 +1,120 @@ +package onnx + +import ( + "testing" + + pb "github.com/owulveryck/onnx-go/internal/pb-onnx" + "gonum.org/v1/gonum/graph" +) + +func TestDecodeProto_self(t *testing.T) { + input := &pb.ModelProto{ + IrVersion: 5, + OpsetImport: []*pb.OperatorSetIdProto{ + &pb.OperatorSetIdProto{ + Domain: "", + Version: 7, + }, + }, + ProducerName: "tf2onnx", + ProducerVersion: "1.5.3", + Domain: "", + ModelVersion: 0, + DocString: "", + Graph: &pb.GraphProto{ + Node: []*pb.NodeProto{ + &pb.NodeProto{ + Input: []string{ + "x:0", + "x:0", + }, + Output: []string{ + "mul:0", + }, + Name: "mul", + OpType: "Mul", + Domain: "", + Attribute: nil, + DocString: "", + }, + }, + Name: "tf2onnx", + Initializer: nil, + DocString: "converted from ./model_nowind_test/export/", + Input: []*pb.ValueInfoProto{ + &pb.ValueInfoProto{ + Name: "x:0", + Type: &pb.TypeProto{ + Value: &pb.TypeProto_TensorType{ + TensorType: &pb.TypeProto_Tensor{ + ElemType: 1, + Shape: &pb.TensorShapeProto{ + Dim: []*pb.TensorShapeProto_Dimension{ + &pb.TensorShapeProto_Dimension{ + Value: &pb.TensorShapeProto_Dimension_DimValue{ + DimValue: 1, + }, + Denotation: "", + }, + }, + }, + }, + }, + Denotation: "", + }, + DocString: "", + }, + }, + Output: []*pb.ValueInfoProto{ + &pb.ValueInfoProto{ + Name: "mul:0", + Type: &pb.TypeProto{ + Value: &pb.TypeProto_TensorType{ + TensorType: &pb.TypeProto_Tensor{ + ElemType: 1, + Shape: &pb.TensorShapeProto{ + Dim: []*pb.TensorShapeProto_Dimension{ + &pb.TensorShapeProto_Dimension{ + Value: &pb.TensorShapeProto_Dimension_DimValue{ + DimValue: 1, + }, + Denotation: "", + }, + }, + }, + }, + }, + Denotation: "", + }, + DocString: "", + }, + }, + ValueInfo: nil, + }, + MetadataProps: nil, + } + backend := newTestBackend() + + m := NewModel(backend) + err := m.decodeProto(input) + if err != nil { + t.Fatal(err) + } + edges := backend.WeightedEdges() + if edges.Len() != 2 { + t.Fatal("expected 2 weighted edges") + } + ee := make([]graph.WeightedEdge, 2) + for i := 0; edges.Next(); i++ { + ee[i] = edges.WeightedEdge() + } + for i := 0; i < len(ee); i++ { + if ee[i].From() == ee[i].To() && ee[i].Weight() == SelfEdge { + ee = ee[:len(ee)-1] + } + if ee[i].From() != ee[i].To() && ee[i].Weight() == 1 { + ee = ee[:len(ee)-1] + } + } + +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 00000000..642b2c86 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,20 @@ +package onnx + +import "testing" + +func TestContains(t *testing.T) { + table := []string{"a", "b", "c"} + ok := contains("a", table) + if !ok { + t.Fail() + } + ok = contains("z", table) + if ok { + t.Fail() + } + table = []string{"a", "a", "b", "c"} + ok = contains("a", table) + if !ok { + t.Fail() + } +} diff --git a/weighed_graph_test.go b/weighed_graph_test.go index 7fcd28ab..0d7252dd 100644 --- a/weighed_graph_test.go +++ b/weighed_graph_test.go @@ -1,15 +1,13 @@ package onnx import ( - "math" - "gonum.org/v1/gonum/graph" "gonum.org/v1/gonum/graph/iterator" "gonum.org/v1/gonum/graph/simple" ) const ( - self, absent = math.MaxFloat64, float64(-1) + absent = float64(-1) ) type edge struct { @@ -149,7 +147,7 @@ func (g *testExpectedGraph) To(id int64) graph.Nodes { // exists between x and y or if x and y have the same ID, false otherwise. func (g *testExpectedGraph) Weight(xid, yid int64) (w float64, ok bool) { if xid == yid { - return self, true + return SelfEdge, true } if to, ok := g.from[xid]; ok { if e, ok := to[yid]; ok {