Skip to content
Draft
5 changes: 5 additions & 0 deletions backend.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package onnx

import (
"math"

"gonum.org/v1/gonum/graph"
)

// SelfEdge is the weight of a self edge in the graph
const SelfEdge = math.MaxFloat64

// Backend represent any backend able to receive a computation graph
type Backend interface {
OperationCarrier
Expand Down
7 changes: 6 additions & 1 deletion decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ func (m *Model) decodeProto(model *pb.ModelProto) error {
m.dbByName[output] = no
}
// input should be ordered for non-commutatives operations
for i, input := range node.Input {
for i := 0; i < len(node.Input); i++ {
input := node.Input[i]
var ni graph.Node
var ok bool
if ni, ok = m.dbByName[input]; !ok {
Expand All @@ -194,6 +195,10 @@ func (m *Model) decodeProto(model *pb.ModelProto) error {
m.dbByName[input] = ni
}
e := dst.NewWeightedEdge(no, ni, float64(i))
if i < len(node.Input)-1 && contains(input, node.Input[i+1:]) {
node.Input = append(node.Input[:i], node.Input[i+1:]...)
e = dst.NewWeightedEdge(no, no, SelfEdge)
}
dst.SetWeightedEdge(e)
}
outputNodes[i] = no
Expand Down
9 changes: 7 additions & 2 deletions dummy_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package onnx

import (
"fmt"
"math"

"gonum.org/v1/gonum/graph"
"gonum.org/v1/gonum/graph/encoding"
Expand Down Expand Up @@ -91,7 +90,7 @@ func (n *nodeTest) ApplyTensor(t tensor.Tensor) error {
// NewSimpleGraph ...
func newTestBackend() *testBackend {
return &testBackend{
g: simple.NewWeightedDirectedGraph(math.MaxFloat64, -1),
g: simple.NewWeightedDirectedGraph(SelfEdge, -1),
}
}

Expand All @@ -107,6 +106,7 @@ func (g *testBackend) AddNode(n graph.Node) {
g.g.AddNode(n)

}

func (g *testBackend) NewNode() graph.Node {
n := g.g.NewNode()
return &nodeTest{
Expand Down Expand Up @@ -151,3 +151,8 @@ func (g *testBackend) ApplyOperation(_ Operation, _ ...graph.Node) error {
func (g *testBackend) WeightedEdge(uid, vid int64) graph.WeightedEdge {
return g.g.WeightedEdge(uid, vid)
}

// WeightedEdges returns all the weighted edges in the graph.
func (g *testBackend) WeightedEdges() graph.WeightedEdges {
return g.g.WeightedEdges()
}
25 changes: 25 additions & 0 deletions internal/tools/dump/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package main

import (
"io/ioutil"
"log"
"os"

"github.com/owulveryck/onnx-go/internal/pb-onnx"
"github.com/sanity-io/litter"
)

func main() {
onnxFile := os.Args[1]
b, err := ioutil.ReadFile(onnxFile)
if err != nil {
log.Fatal(err)
}
var m pb.ModelProto
err = m.XXX_Unmarshal(b)
if err != nil {
log.Fatal(err)
}
litter.Dump(m)

}
120 changes: 120 additions & 0 deletions self_node_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package onnx

import (
"testing"

pb "github.com/owulveryck/onnx-go/internal/pb-onnx"
"gonum.org/v1/gonum/graph"
)

func TestDecodeProto_self(t *testing.T) {
input := &pb.ModelProto{
IrVersion: 5,
OpsetImport: []*pb.OperatorSetIdProto{
&pb.OperatorSetIdProto{
Domain: "",
Version: 7,
},
},
ProducerName: "tf2onnx",
ProducerVersion: "1.5.3",
Domain: "",
ModelVersion: 0,
DocString: "",
Graph: &pb.GraphProto{
Node: []*pb.NodeProto{
&pb.NodeProto{
Input: []string{
"x:0",
"x:0",
},
Output: []string{
"mul:0",
},
Name: "mul",
OpType: "Mul",
Domain: "",
Attribute: nil,
DocString: "",
},
},
Name: "tf2onnx",
Initializer: nil,
DocString: "converted from ./model_nowind_test/export/",
Input: []*pb.ValueInfoProto{
&pb.ValueInfoProto{
Name: "x:0",
Type: &pb.TypeProto{
Value: &pb.TypeProto_TensorType{
TensorType: &pb.TypeProto_Tensor{
ElemType: 1,
Shape: &pb.TensorShapeProto{
Dim: []*pb.TensorShapeProto_Dimension{
&pb.TensorShapeProto_Dimension{
Value: &pb.TensorShapeProto_Dimension_DimValue{
DimValue: 1,
},
Denotation: "",
},
},
},
},
},
Denotation: "",
},
DocString: "",
},
},
Output: []*pb.ValueInfoProto{
&pb.ValueInfoProto{
Name: "mul:0",
Type: &pb.TypeProto{
Value: &pb.TypeProto_TensorType{
TensorType: &pb.TypeProto_Tensor{
ElemType: 1,
Shape: &pb.TensorShapeProto{
Dim: []*pb.TensorShapeProto_Dimension{
&pb.TensorShapeProto_Dimension{
Value: &pb.TensorShapeProto_Dimension_DimValue{
DimValue: 1,
},
Denotation: "",
},
},
},
},
},
Denotation: "",
},
DocString: "",
},
},
ValueInfo: nil,
},
MetadataProps: nil,
}
backend := newTestBackend()

m := NewModel(backend)
err := m.decodeProto(input)
if err != nil {
t.Fatal(err)
}
edges := backend.WeightedEdges()
if edges.Len() != 2 {
t.Fatal("expected 2 weighted edges")
}
ee := make([]graph.WeightedEdge, 2)
for i := 0; edges.Next(); i++ {
ee[i] = edges.WeightedEdge()
}
for i := 0; i < len(ee); i++ {
if ee[i].From() == ee[i].To() && ee[i].Weight() == SelfEdge {
ee = ee[:len(ee)-1]
}
if ee[i].From() != ee[i].To() && ee[i].Weight() == 1 {
ee = ee[:len(ee)-1]
}
}

}
20 changes: 20 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package onnx

import "testing"

func TestContains(t *testing.T) {
table := []string{"a", "b", "c"}
ok := contains("a", table)
if !ok {
t.Fail()
}
ok = contains("z", table)
if ok {
t.Fail()
}
table = []string{"a", "a", "b", "c"}
ok = contains("a", table)
if !ok {
t.Fail()
}
}
6 changes: 2 additions & 4 deletions weighed_graph_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package onnx

import (
"math"

"gonum.org/v1/gonum/graph"
"gonum.org/v1/gonum/graph/iterator"
"gonum.org/v1/gonum/graph/simple"
)

const (
self, absent = math.MaxFloat64, float64(-1)
absent = float64(-1)
)

type edge struct {
Expand Down Expand Up @@ -149,7 +147,7 @@ func (g *testExpectedGraph) To(id int64) graph.Nodes {
// exists between x and y or if x and y have the same ID, false otherwise.
func (g *testExpectedGraph) Weight(xid, yid int64) (w float64, ok bool) {
if xid == yid {
return self, true
return SelfEdge, true
}
if to, ok := g.from[xid]; ok {
if e, ok := to[yid]; ok {
Expand Down