Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ type Node struct {
// Used when we want to break between the field name and values when a
// single-line node exceeds the requested wrap column.
PutSingleValueOnNextLine bool
// Field number from proto definition (0 if unknown/not applicable).
FieldNumber int32
}

// NodeLess is a sorting function that compares two *Nodes, possibly using the parent Node
Expand Down Expand Up @@ -267,6 +269,32 @@ func ByFieldSubfieldPath(field string, subfieldPath []string, projection func(st
}
}

// ByFieldNumber is a NodeLess function that orders nodes by the proto field
// number stored in Node.FieldNumber (filled in while parsing when a descriptor
// is available).
//
// It only expresses an ordering when the whole slice is being sorted; otherwise
// it always reports false. The rules are:
//   - two nodes with positive field numbers compare by number;
//   - a node with a positive number sorts before one whose number is 0;
//   - every other pair compares by Name.
func ByFieldNumber(_, ni, nj *Node, isWholeSlice bool) bool {
	if !isWholeSlice {
		return false
	}

	left, right := ni.FieldNumber, nj.FieldNumber
	switch {
	case left > 0 && right > 0:
		// Both known: plain numeric order.
		return left < right
	case left > 0 && right == 0:
		// Only ni is known, so it goes first.
		return true
	case left == 0 && right > 0:
		// Only nj is known, so it goes first.
		return false
	default:
		// No usable numbers: alphabetical by field name.
		return ni.Name < nj.Name
	}
}

// getChildValue returns the Value of the child with the given field name,
// or nil if no single such child exists.
func (n *Node) getChildValue(field string) *Value {
Expand Down
6 changes: 6 additions & 0 deletions cmd/txtpbfmt/fmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ var (
expandAllChildren = flag.Bool("expand_all_children", false, "Expand all children irrespective of initial state.")
skipAllColons = flag.Bool("skip_all_colons", false, "Skip colons whenever possible.")
sortFieldsByFieldName = flag.Bool("sort_fields_by_field_name", false, "Sort fields by field name.")
sortFieldsByFieldNumber = flag.Bool("sort_fields_by_field_number", false, "Sort fields by field number from proto definition.")
protoDescriptor = flag.String("proto_descriptor", "", "Path to protobuf descriptor file (.desc)")
messageFullName = flag.String("message_full_name", "", "Full message type name for field number lookup (required, e.g. google.protobuf.Any)")
sortRepeatedFieldsByContent = flag.Bool("sort_repeated_fields_by_content", false, "Sort adjacent scalar fields of the same field name by their contents.")
sortRepeatedFieldsBySubfield = flag.String("sort_repeated_fields_by_subfield", "", "Sort adjacent message fields of the given field name by the contents of the given subfield.")
removeDuplicateValuesForRepeatedFields = flag.Bool("remove_duplicate_values_for_repeated_fields", false, "Remove lines that have the same field name and scalar value as another.")
Expand Down Expand Up @@ -88,6 +91,9 @@ func processPath(path string) error {
ExpandAllChildren: *expandAllChildren,
SkipAllColons: *skipAllColons,
SortFieldsByFieldName: *sortFieldsByFieldName,
SortFieldsByFieldNumber: *sortFieldsByFieldNumber,
ProtoDescriptor: *protoDescriptor,
MessageFullName: *messageFullName,
SortRepeatedFieldsByContent: *sortRepeatedFieldsByContent,
SortRepeatedFieldsBySubfield: strings.Split(*sortRepeatedFieldsBySubfield, ","),
RemoveDuplicateValuesForRepeatedFields: *removeDuplicateValuesForRepeatedFields,
Expand Down
9 changes: 9 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ type Config struct {
// Sort fields by field name.
SortFieldsByFieldName bool

// Sort fields by field number from proto definition.
SortFieldsByFieldNumber bool

// Path to protobuf descriptor file (.desc).
ProtoDescriptor string

// Full message type name for field number lookup (required, e.g. google.protobuf.Any).
MessageFullName string

// Sort adjacent scalar fields of the same field name by their contents.
SortRepeatedFieldsByContent bool

Expand Down
83 changes: 83 additions & 0 deletions descriptor/descriptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Package descriptor provides functionality to load and parse Protocol Buffer descriptor files.
package descriptor

import (
	"fmt"
	"os"
	"sort"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protodesc"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/types/descriptorpb"
)

// Loader provides functionality to load field numbers from descriptor files.
// A Loader is created by NewLoader and is then queried for message descriptors
// by full name through GetRootMessageDescriptor.
type Loader struct {
	// descriptorFile is the descriptor file path that was given to NewLoader.
	descriptorFile string
	// files is the registry NewLoader builds from the file's FileDescriptorSet;
	// GetRootMessageDescriptor treats a nil registry as "not loaded".
	files *protoregistry.Files
}

// NewLoader creates a new descriptor loader for the given descriptor file.
//
// descriptorFile must be the path of a file containing a serialized
// FileDescriptorSet. The file is read and parsed eagerly, so a returned Loader
// is ready for lookups.
//
// An error is returned when the path is empty, the file cannot be read, its
// contents cannot be unmarshaled, or the descriptors cannot be turned into a
// file registry. Underlying errors are wrapped with %w so callers can inspect
// them with errors.Is / errors.As (for example fs.ErrNotExist).
func NewLoader(descriptorFile string) (*Loader, error) {
	if descriptorFile == "" {
		return nil, fmt.Errorf("descriptor file is required")
	}

	data, err := os.ReadFile(descriptorFile)
	if err != nil {
		return nil, fmt.Errorf("failed to read descriptor file %s: %w", descriptorFile, err)
	}

	fileDescSet := &descriptorpb.FileDescriptorSet{}
	if err := proto.Unmarshal(data, fileDescSet); err != nil {
		return nil, fmt.Errorf("failed to unmarshal descriptor file %s: %w", descriptorFile, err)
	}

	files, err := protodesc.NewFiles(fileDescSet)
	if err != nil {
		return nil, fmt.Errorf("failed to create files from descriptor file %s: %w", descriptorFile, err)
	}

	return &Loader{
		descriptorFile: descriptorFile,
		files:          files,
	}, nil
}

// GetRootMessageDescriptor returns the root message descriptor for the specified messageFullName.
// messageFullName is required and must be a valid full name (e.g., "google.protobuf.Any").
//
// Errors are returned when:
//   - the loader holds no descriptors (nil receiver or nil registry);
//   - messageFullName is empty, in which case the error lists the top-level
//     messages of every loaded file, sorted so the text is reproducible;
//   - the name cannot be resolved (the registry error is wrapped with %w);
//   - the name resolves to something that is not a message.
func (l *Loader) GetRootMessageDescriptor(messageFullName string) (protoreflect.MessageDescriptor, error) {
	if l == nil || l.files == nil {
		return nil, fmt.Errorf("descriptor not loaded, call NewLoader() first")
	}

	if messageFullName == "" {
		// Collect the available top-level messages to help the user pick one.
		var availableMessages []string
		l.files.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
			messages := fd.Messages()
			for i := 0; i < messages.Len(); i++ {
				availableMessages = append(availableMessages, string(messages.Get(i).FullName()))
			}
			return true
		})

		if len(availableMessages) == 0 {
			return nil, fmt.Errorf("no messages found in descriptor")
		}
		// RangeFiles visits files in an undefined order; sort for a stable message.
		sort.Strings(availableMessages)
		return nil, fmt.Errorf("message_full_name is required. Available messages: %v", availableMessages)
	}

	// Find the specific message type.
	desc, err := l.files.FindDescriptorByName(protoreflect.FullName(messageFullName))
	if err != nil {
		return nil, fmt.Errorf("message type %s not found: %w", messageFullName, err)
	}
	msgDesc, ok := desc.(protoreflect.MessageDescriptor)
	if !ok {
		return nil, fmt.Errorf("%s is not a message type", messageFullName)
	}
	return msgDesc, nil
}
74 changes: 74 additions & 0 deletions descriptor/descriptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package descriptor

import (
"testing"

// Google internal testing/gobase/runfilestest package, commented out by copybara
)

// TestNewLoader checks that NewLoader succeeds for ../testdata/test.desc and
// reports an error for an empty path and for a path that does not exist.
func TestNewLoader(t *testing.T) {
	t.Run("valid descriptor file", func(t *testing.T) {
		loader, err := NewLoader("../testdata/test.desc")
		if err != nil {
			t.Fatalf("Failed to create loader: %v", err)
		}
		if loader == nil {
			t.Fatal("Expected non-nil loader")
		}
	})

	// Inputs for which NewLoader must fail, with the message to report if it
	// unexpectedly succeeds.
	failing := []struct {
		name    string
		path    string
		failMsg string
	}{
		{name: "empty descriptor file path", path: "", failMsg: "Expected error for empty path"},
		{name: "non-existent file", path: "nonexistent.desc", failMsg: "Expected error for non-existent file"},
	}
	for _, tc := range failing {
		t.Run(tc.name, func(t *testing.T) {
			if _, err := NewLoader(tc.path); err == nil {
				t.Error(tc.failMsg)
			}
		})
	}
}

// TestGetRootMessageDescriptor loads ../testdata/test.desc once and checks, per
// case, whether looking up the given full name yields a descriptor or an error.
func TestGetRootMessageDescriptor(t *testing.T) {
	loader, err := NewLoader("../testdata/test.desc")
	if err != nil {
		t.Fatalf("Failed to create loader: %v", err)
	}

	type lookupCase struct {
		name            string
		messageFullName string
		wantError       bool
	}
	cases := []lookupCase{
		{name: "UserProfile", messageFullName: "testproto.UserProfile"},
		{name: "ProductCatalog", messageFullName: "testproto.ProductCatalog"},
		{name: "Level1Config", messageFullName: "testproto.Level1Config"},
		{name: "nested message", messageFullName: "testproto.Level1Config.Level2Config"},
		{name: "empty name", messageFullName: "", wantError: true},
		{name: "non-existent", messageFullName: "testproto.NonExistent", wantError: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			desc, err := loader.GetRootMessageDescriptor(tc.messageFullName)

			if tc.wantError {
				if err == nil {
					t.Error("Expected error but got none")
				}
				return
			}

			// Success path: both checks run so a failure reports everything wrong.
			if err != nil {
				t.Errorf("Unexpected error: %v", err)
			}
			if desc == nil {
				t.Error("Expected descriptor but got nil")
			}
		})
	}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ require (
github.com/google/go-cmp v0.6.0
github.com/kylelemons/godebug v1.1.0
github.com/mitchellh/go-wordwrap v1.0.1
google.golang.org/protobuf v1.33.0
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
67 changes: 61 additions & 6 deletions impl/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"strconv"
"strings"

"google.golang.org/protobuf/reflect/protoreflect"
"github.com/protocolbuffers/txtpbfmt/ast"
"github.com/protocolbuffers/txtpbfmt/config"
"github.com/protocolbuffers/txtpbfmt/descriptor"
"github.com/protocolbuffers/txtpbfmt/quote"
"github.com/protocolbuffers/txtpbfmt/sort"
"github.com/protocolbuffers/txtpbfmt/wrap"
Expand Down Expand Up @@ -148,13 +150,33 @@ func ParseWithMetaCommentConfig(in []byte, c config.Config) ([]*ast.Node, error)
if err != nil {
return nil, err
}

// Load descriptor if field number sorting is enabled
var rootDesc protoreflect.MessageDescriptor
if c.SortFieldsByFieldNumber {
if c.ProtoDescriptor == "" {
return nil, fmt.Errorf("proto_descriptor is required when using sort_fields_by_field_number")
}

loader, err := descriptor.NewLoader(c.ProtoDescriptor)
if err != nil {
return nil, fmt.Errorf("failed to create descriptor loader: %v", err)
}

// Get root message descriptor
rootDesc, err = loader.GetRootMessageDescriptor(c.MessageFullName)
if err != nil {
return nil, fmt.Errorf("failed to get root message descriptor: %v", err)
}
}

if p.config.InfoLevel() {
p.config.Infof("p.in: %q", string(p.in))
p.config.Infof("p.length: %v", p.length)
}
// Although unnamed nodes aren't strictly allowed, some formats represent a
// list of protos as a list of unnamed top-level nodes.
nodes, _, err := p.parse( /*isRoot=*/ true)
nodes, _, err := p.parse( /*isRoot=*/ true, rootDesc)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -288,6 +310,35 @@ func newParser(in []byte, c config.Config) (*parser, error) {
return parser, nil
}

// getFieldNumber looks up fieldName, by its text-format name, among the fields
// of desc and returns that field's number. It returns 0 when desc is nil or
// when desc has no field with that name.
func getFieldNumber(desc protoreflect.MessageDescriptor, fieldName string) int32 {
	if desc == nil {
		return 0
	}
	if fd := desc.Fields().ByTextName(fieldName); fd != nil {
		return int32(fd.Number())
	}
	return 0
}

// findChildDescriptor returns the message descriptor that describes the
// contents of the field named fieldName (text-format name) within desc.
//
// It returns nil when desc is nil, when desc has no such field, or when the
// field is not message-typed. Group fields are accepted alongside regular
// message fields: they are written as nested blocks in text format as well, so
// their children need a descriptor to resolve field numbers.
func (p *parser) findChildDescriptor(desc protoreflect.MessageDescriptor, fieldName string) protoreflect.MessageDescriptor {
	if desc == nil {
		return nil
	}

	field := desc.Fields().ByTextName(fieldName)
	if field == nil {
		return nil
	}
	switch field.Kind() {
	case protoreflect.MessageKind, protoreflect.GroupKind:
		return field.Message()
	default:
		return nil
	}
}

// nextInputIs reports whether there is unread input left and the byte at the
// current index equals b.
func (p *parser) nextInputIs(b byte) bool {
	if p.index >= p.length {
		return false
	}
	return p.in[p.index] == b
}
Expand Down Expand Up @@ -398,7 +449,7 @@ func (p *parser) consumeOptionalSeparator() error {
// format (sequence of messages, each of which passes proto.UnmarshalText()).
// endPos is the position of the first character on the first line
// after parsed nodes: that's the position to append more children.
func (p *parser) parse(isRoot bool) (result []*ast.Node, endPos ast.Position, err error) {
func (p *parser) parse(isRoot bool, desc protoreflect.MessageDescriptor) (result []*ast.Node, endPos ast.Position, err error) {
var res []*ast.Node
res = []*ast.Node{} // empty children is different from nil children
for ld := p.getLoopDetector(); p.index < p.length; {
Expand Down Expand Up @@ -505,14 +556,17 @@ func (p *parser) parse(isRoot bool) (result []*ast.Node, endPos ast.Position, er
return nil, ast.Position{}, err
}

// Set field number from descriptor if available
nd.FieldNumber = getFieldNumber(desc, nd.Name)

// Skip separator.
preCommentsBeforeColon, _ := p.skipWhiteSpaceAndReadComments(true /* multiLine */)
nd.SkipColon = !p.consume(':')
previousPos := p.position()
preCommentsAfterColon, _ := p.skipWhiteSpaceAndReadComments(true /* multiLine */)

if p.consume('{') || p.consume('<') {
if err := p.parseMessage(nd); err != nil {
if err := p.parseMessage(nd, desc); err != nil {
return nil, ast.Position{}, err
}
} else if p.consume('[') {
Expand Down Expand Up @@ -562,14 +616,15 @@ func (p *parser) parseFieldName(nd *ast.Node, isRoot bool) error {
return nil
}

func (p *parser) parseMessage(nd *ast.Node) error {
func (p *parser) parseMessage(nd *ast.Node, desc protoreflect.MessageDescriptor) error {
if p.config.SkipAllColons {
nd.SkipColon = true
}
nd.ChildrenSameLine = p.bracketSameLine[p.index-1]
nd.IsAngleBracket = p.config.PreserveAngleBrackets && p.in[p.index-1] == '<'
// Recursive call to parse child nodes.
nodes, lastPos, err := p.parse( /*isRoot=*/ false)
childDesc := p.findChildDescriptor(desc, nd.Name)
nodes, lastPos, err := p.parse( /*isRoot=*/ false, childDesc)
if err != nil {
return err
}
Expand All @@ -595,7 +650,7 @@ func (p *parser) parseList(nd *ast.Node, preCommentsBeforeColon, preCommentsAfte
// Handle list of nodes.
nd.ChildrenAsList = true

nodes, lastPos, err := p.parse( /*isRoot=*/ true)
nodes, lastPos, err := p.parse( /*isRoot=*/ true, nil)
if err != nil {
return err
}
Expand Down
Loading
Loading