From ee929ab88d09e61a93cd041730adfedb1aa2c510 Mon Sep 17 00:00:00 2001
From: Tibor Vass <teabee89@gmail.com>
Date: Fri, 7 Aug 2015 12:49:35 -0700
Subject: [PATCH 1/3] Expose Encoder.Canonical() and MarshalCanonical() Handles
 lexicographic order in struct fields, rejects floating numbers and handles
 strings as defined in http://wiki.laptop.org/go/Canonical_JSON except for
 unicode normalization. IOW, only escaping allowed is " and \.

The notion of Canonical JSON is only for an Encoder.

Signed-off-by: Tibor Vass <teabee89@gmail.com>
---
 canonical/json/decode.go      |   3 +-
 canonical/json/encode.go      | 141 ++++++++++++++++++++++++++++--------------
 canonical/json/encode_test.go |  50 +++++++++++++++
 canonical/json/stream.go      |  27 +++++---
 4 files changed, 162 insertions(+), 59 deletions(-)

diff --git a/canonical/json/decode.go b/canonical/json/decode.go
index 705bc2e..51b4ffe 100644
--- a/canonical/json/decode.go
+++ b/canonical/json/decode.go
@@ -174,6 +174,7 @@ type decodeState struct {
 	nextscan   scanner // for calls to nextValue
 	savedError error
 	useNumber  bool
+	canonical  bool
 }
 
 // errPhase is used for errors that should not happen unless
@@ -557,7 +558,7 @@ func (d *decodeState) object(v reflect.Value) {
 			subv = mapElem
 		} else {
 			var f *field
-			fields := cachedTypeFields(v.Type())
+			fields := cachedTypeFields(v.Type(), false)
 			for i := range fields {
 				ff := &fields[i]
 				if bytes.Equal(ff.nameBytes, key) {
diff --git a/canonical/json/encode.go b/canonical/json/encode.go
index fca2a09..aaa79c2 100644
--- a/canonical/json/encode.go
+++ b/canonical/json/encode.go
@@ -131,12 +131,7 @@ import (
 // an infinite recursion.
 //
 func Marshal(v interface{}) ([]byte, error) {
-	e := &encodeState{}
-	err := e.marshal(v)
-	if err != nil {
-		return nil, err
-	}
-	return e.Bytes(), nil
+	return marshal(v, false)
 }
 
 // MarshalIndent is like Marshal but applies Indent to format the output.
@@ -153,6 +148,21 @@ func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) {
 	return buf.Bytes(), nil
 }
 
+// MarshalCanonical is like Marshal but encodes into Canonical JSON.
+// Read more at: http://wiki.laptop.org/go/Canonical_JSON
+func MarshalCanonical(v interface{}) ([]byte, error) {
+	return marshal(v, true)
+}
+
+func marshal(v interface{}, canonical bool) ([]byte, error) {
+	e := &encodeState{canonical: canonical}
+	err := e.marshal(v)
+	if err != nil {
+		return nil, err
+	}
+	return e.Bytes(), nil
+}
+
 // HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029
 // characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029
 // so that the JSON will be safe to embed inside HTML <script> tags.
@@ -242,17 +252,19 @@ var hex = "0123456789abcdef"
 type encodeState struct {
 	bytes.Buffer // accumulated output
 	scratch      [64]byte
+	canonical    bool
 }
 
 var encodeStatePool sync.Pool
 
-func newEncodeState() *encodeState {
+func newEncodeState(canonical bool) *encodeState {
 	if v := encodeStatePool.Get(); v != nil {
 		e := v.(*encodeState)
 		e.Reset()
+		e.canonical = canonical
 		return e
 	}
-	return new(encodeState)
+	return &encodeState{canonical: canonical}
 }
 
 func (e *encodeState) marshal(v interface{}) (err error) {
@@ -296,7 +308,7 @@ func isEmptyValue(v reflect.Value) bool {
 }
 
 func (e *encodeState) reflectValue(v reflect.Value) {
-	valueEncoder(v)(e, v, false)
+	e.valueEncoder(v)(e, v, false)
 }
 
 type encoderFunc func(e *encodeState, v reflect.Value, quoted bool)
@@ -306,14 +318,14 @@ var encoderCache struct {
 	m map[reflect.Type]encoderFunc
 }
 
-func valueEncoder(v reflect.Value) encoderFunc {
+func (e *encodeState) valueEncoder(v reflect.Value) encoderFunc {
 	if !v.IsValid() {
 		return invalidValueEncoder
 	}
-	return typeEncoder(v.Type())
+	return e.typeEncoder(v.Type())
 }
 
-func typeEncoder(t reflect.Type) encoderFunc {
+func (e *encodeState) typeEncoder(t reflect.Type) encoderFunc {
 	encoderCache.RLock()
 	f := encoderCache.m[t]
 	encoderCache.RUnlock()
@@ -339,7 +351,7 @@ func typeEncoder(t reflect.Type) encoderFunc {
 
 	// Compute fields without lock.
 	// Might duplicate effort but won't hold other computations back.
-	f = newTypeEncoder(t, true)
+	f = e.newTypeEncoder(t, true)
 	wg.Done()
 	encoderCache.Lock()
 	encoderCache.m[t] = f
@@ -354,13 +366,13 @@ var (
 
 // newTypeEncoder constructs an encoderFunc for a type.
 // The returned encoder only checks CanAddr when allowAddr is true.
-func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
+func (e *encodeState) newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
 	if t.Implements(marshalerType) {
 		return marshalerEncoder
 	}
 	if t.Kind() != reflect.Ptr && allowAddr {
 		if reflect.PtrTo(t).Implements(marshalerType) {
-			return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
+			return newCondAddrEncoder(addrMarshalerEncoder, e.newTypeEncoder(t, false))
 		}
 	}
 
@@ -369,7 +381,7 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
 	}
 	if t.Kind() != reflect.Ptr && allowAddr {
 		if reflect.PtrTo(t).Implements(textMarshalerType) {
-			return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
+			return newCondAddrEncoder(addrTextMarshalerEncoder, e.newTypeEncoder(t, false))
 		}
 	}
 
@@ -389,15 +401,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
 	case reflect.Interface:
 		return interfaceEncoder
 	case reflect.Struct:
-		return newStructEncoder(t)
+		return e.newStructEncoder(t)
 	case reflect.Map:
-		return newMapEncoder(t)
+		return e.newMapEncoder(t)
 	case reflect.Slice:
-		return newSliceEncoder(t)
+		return e.newSliceEncoder(t)
 	case reflect.Array:
-		return newArrayEncoder(t)
+		return e.newArrayEncoder(t)
 	case reflect.Ptr:
-		return newPtrEncoder(t)
+		return e.newPtrEncoder(t)
 	default:
 		return unsupportedTypeEncoder
 	}
@@ -511,7 +523,7 @@ type floatEncoder int // number of bits
 
 func (bits floatEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
 	f := v.Float()
-	if math.IsInf(f, 0) || math.IsNaN(f) {
+	if math.IsInf(f, 0) || math.IsNaN(f) || (e.canonical && math.Floor(f) != f) {
 		e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))})
 	}
 	b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, int(bits))
@@ -586,14 +598,14 @@ func (se *structEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
 	e.WriteByte('}')
 }
 
-func newStructEncoder(t reflect.Type) encoderFunc {
-	fields := cachedTypeFields(t)
+func (e *encodeState) newStructEncoder(t reflect.Type) encoderFunc {
+	fields := cachedTypeFields(t, e.canonical)
 	se := &structEncoder{
 		fields:    fields,
 		fieldEncs: make([]encoderFunc, len(fields)),
 	}
 	for i, f := range fields {
-		se.fieldEncs[i] = typeEncoder(typeByIndex(t, f.index))
+		se.fieldEncs[i] = e.typeEncoder(typeByIndex(t, f.index))
 	}
 	return se.encode
 }
@@ -621,11 +633,11 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
 	e.WriteByte('}')
 }
 
-func newMapEncoder(t reflect.Type) encoderFunc {
+func (e *encodeState) newMapEncoder(t reflect.Type) encoderFunc {
 	if t.Key().Kind() != reflect.String {
 		return unsupportedTypeEncoder
 	}
-	me := &mapEncoder{typeEncoder(t.Elem())}
+	me := &mapEncoder{e.typeEncoder(t.Elem())}
 	return me.encode
 }
 
@@ -664,12 +676,12 @@ func (se *sliceEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
 	se.arrayEnc(e, v, false)
 }
 
-func newSliceEncoder(t reflect.Type) encoderFunc {
+func (e *encodeState) newSliceEncoder(t reflect.Type) encoderFunc {
 	// Byte slices get special treatment; arrays don't.
 	if t.Elem().Kind() == reflect.Uint8 {
 		return encodeByteSlice
 	}
-	enc := &sliceEncoder{newArrayEncoder(t)}
+	enc := &sliceEncoder{e.newArrayEncoder(t)}
 	return enc.encode
 }
 
@@ -689,8 +701,8 @@ func (ae *arrayEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
 	e.WriteByte(']')
 }
 
-func newArrayEncoder(t reflect.Type) encoderFunc {
-	enc := &arrayEncoder{typeEncoder(t.Elem())}
+func (e *encodeState) newArrayEncoder(t reflect.Type) encoderFunc {
+	enc := &arrayEncoder{e.typeEncoder(t.Elem())}
 	return enc.encode
 }
 
@@ -706,8 +718,8 @@ func (pe *ptrEncoder) encode(e *encodeState, v reflect.Value, quoted bool) {
 	pe.elemEnc(e, v.Elem(), quoted)
 }
 
-func newPtrEncoder(t reflect.Type) encoderFunc {
-	enc := &ptrEncoder{typeEncoder(t.Elem())}
+func (e *encodeState) newPtrEncoder(t reflect.Type) encoderFunc {
+	enc := &ptrEncoder{e.typeEncoder(t.Elem())}
 	return enc.encode
 }
 
@@ -788,9 +800,11 @@ func (e *encodeState) string(s string) (int, error) {
 	start := 0
 	for i := 0; i < len(s); {
 		if b := s[i]; b < utf8.RuneSelf {
-			if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
-				i++
-				continue
+			if b != '\\' && b != '"' {
+				if e.canonical || (0x20 <= b && b != '<' && b != '>' && b != '&') {
+					i++
+					continue
+				}
 			}
 			if start < i {
 				e.WriteString(s[start:i])
@@ -821,6 +835,10 @@ func (e *encodeState) string(s string) (int, error) {
 			start = i
 			continue
 		}
+		if e.canonical {
+			i++
+			continue
+		}
 		c, size := utf8.DecodeRuneInString(s[i:])
 		if c == utf8.RuneError && size == 1 {
 			if start < i {
@@ -864,9 +882,11 @@ func (e *encodeState) stringBytes(s []byte) (int, error) {
 	start := 0
 	for i := 0; i < len(s); {
 		if b := s[i]; b < utf8.RuneSelf {
-			if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
-				i++
-				continue
+			if b != '\\' && b != '"' {
+				if e.canonical || (0x20 <= b && b != '<' && b != '>' && b != '&') {
+					i++
+					continue
+				}
 			}
 			if start < i {
 				e.Write(s[start:i])
@@ -897,6 +917,10 @@ func (e *encodeState) stringBytes(s []byte) (int, error) {
 			start = i
 			continue
 		}
+		if e.canonical {
+			i++
+			continue
+		}
 		c, size := utf8.DecodeRune(s[i:])
 		if c == utf8.RuneError && size == 1 {
 			if start < i {
@@ -1108,10 +1132,7 @@ func typeFields(t reflect.Type) []field {
 		}
 	}
 
-	fields = out
-	sort.Sort(byIndex(fields))
-
-	return fields
+	return out
 }
 
 // dominantField looks through the fields, all of which are known to
@@ -1152,16 +1173,29 @@ func dominantField(fields []field) (field, bool) {
 	return fields[0], true
 }
 
+type fields struct {
+	byName  []field
+	byIndex []field
+}
+
 var fieldCache struct {
 	sync.RWMutex
-	m map[reflect.Type][]field
+	m map[reflect.Type]*fields
 }
 
 // cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
-func cachedTypeFields(t reflect.Type) []field {
+func cachedTypeFields(t reflect.Type, canonical bool) []field {
 	fieldCache.RLock()
-	f := fieldCache.m[t]
+	x := fieldCache.m[t]
 	fieldCache.RUnlock()
+
+	var f []field
+	if x != nil {
+		if canonical {
+			f = x.byName
+		}
+		f = x.byIndex
+	}
 	if f != nil {
 		return f
 	}
@@ -1172,12 +1206,23 @@ func cachedTypeFields(t reflect.Type) []field {
 	if f == nil {
 		f = []field{}
 	}
+	if !canonical {
+		sort.Sort(byIndex(f))
+	}
 
 	fieldCache.Lock()
 	if fieldCache.m == nil {
-		fieldCache.m = map[reflect.Type][]field{}
+		fieldCache.m = map[reflect.Type]*fields{}
 	}
-	fieldCache.m[t] = f
+	x = fieldCache.m[t]
 	fieldCache.Unlock()
+	if x == nil {
+		x = new(fields)
+	}
+	if canonical {
+		x.byName = f
+	} else {
+		x.byIndex = f
+	}
 	return f
 }
diff --git a/canonical/json/encode_test.go b/canonical/json/encode_test.go
index 7abfa85..87c697e 100644
--- a/canonical/json/encode_test.go
+++ b/canonical/json/encode_test.go
@@ -6,6 +6,7 @@ package json
 
 import (
 	"bytes"
+	"fmt"
 	"math"
 	"reflect"
 	"testing"
@@ -530,3 +531,52 @@ func TestEncodeString(t *testing.T) {
 		}
 	}
 }
+
+type CanonicalTestStruct struct {
+	S string
+	F float64
+	I int
+	E *CanonicalTestStruct
+}
+
+func (s *CanonicalTestStruct) String() string {
+	var e interface{} = s.E
+	if s.E == nil {
+		e = "nil"
+	}
+	return fmt.Sprintf("{S:%q F:%v I:%v E:%v}", s.S, s.F, s.I, e)
+}
+
+var encodeCanonicalTests = []struct {
+	in        interface{}
+	out       string
+	expectErr bool
+}{
+	{nil, `null`, false},
+	{&CanonicalTestStruct{}, `{"E":null,"F":0,"I":0,"S":""}`, false},
+	{&CanonicalTestStruct{F: 1.0}, `{"E":null,"F":1,"I":0,"S":""}`, false},
+	// error out on floating numbers
+	{&CanonicalTestStruct{F: 1.2}, ``, true},
+	{&CanonicalTestStruct{S: "foo", E: &CanonicalTestStruct{I: 42}}, `{"E":{"E":null,"F":0,"I":42,"S":""},"F":0,"I":0,"S":"foo"}`, false},
+	// only escape \ and " and keep any other character as-is
+	{"\u0090 \t \\ \n \"", "\"\u0090 \t \\\\ \n \\\"\"", false},
+}
+
+func TestEncodeCanonicalStruct(t *testing.T) {
+	for _, tt := range encodeCanonicalTests {
+		b, err := MarshalCanonical(tt.in)
+		if err != nil {
+			if !tt.expectErr {
+				t.Errorf("MarshalCanonical(%#v) = error(%v), want %s", tt.in, err, tt.out)
+			}
+			continue
+		} else if tt.expectErr {
+			t.Errorf("MarshalCanonical(%#v) expects an error", tt.in)
+			continue
+		}
+		out := string(b)
+		if out != tt.out {
+			t.Errorf("MarshalCanonical(%#v) = %q, want %q", tt.in, out, tt.out)
+		}
+	}
+}
diff --git a/canonical/json/stream.go b/canonical/json/stream.go
index 9566eca..8905550 100644
--- a/canonical/json/stream.go
+++ b/canonical/json/stream.go
@@ -138,8 +138,9 @@ func nonSpace(b []byte) bool {
 
 // An Encoder writes JSON objects to an output stream.
 type Encoder struct {
-	w   io.Writer
-	err error
+	w         io.Writer
+	err       error
+	canonical bool
 }
 
 // NewEncoder returns a new encoder that writes to w.
@@ -147,6 +148,10 @@ func NewEncoder(w io.Writer) *Encoder {
 	return &Encoder{w: w}
 }
 
+// Canonical causes the encoder to switch to Canonical JSON mode.
+// Read more at: http://wiki.laptop.org/go/Canonical_JSON
+func (enc *Encoder) Canonical() { enc.canonical = true }
+
 // Encode writes the JSON encoding of v to the stream,
 // followed by a newline character.
 //
@@ -156,19 +161,21 @@ func (enc *Encoder) Encode(v interface{}) error {
 	if enc.err != nil {
 		return enc.err
 	}
-	e := newEncodeState()
+	e := newEncodeState(enc.canonical)
 	err := e.marshal(v)
 	if err != nil {
 		return err
 	}
 
-	// Terminate each value with a newline.
-	// This makes the output look a little nicer
-	// when debugging, and some kind of space
-	// is required if the encoded value was a number,
-	// so that the reader knows there aren't more
-	// digits coming.
-	e.WriteByte('\n')
+	if !enc.canonical {
+		// Terminate each value with a newline.
+		// This makes the output look a little nicer
+		// when debugging, and some kind of space
+		// is required if the encoded value was a number,
+		// so that the reader knows there aren't more
+		// digits coming.
+		e.WriteByte('\n')
+	}
 
 	if _, err = enc.w.Write(e.Bytes()); err != nil {
 		enc.err = err
-- 
2.5.0

