diff.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. package pretty
  2. import (
  3. "fmt"
  4. "io"
  5. "reflect"
  6. )
  // sbuf is a slice of strings that accumulates formatted messages.
  // Each call to its Printf method adds exactly one element.
  type sbuf []string

  // Printf formats its arguments the way fmt.Sprintf does and appends the
  // resulting string to the slice.
  func (s *sbuf) Printf(format string, a ...interface{}) {
  	*s = append(*s, fmt.Sprintf(format, a...))
  }
  12. // Diff returns a slice where each element describes
  13. // a difference between a and b.
  14. func Diff(a, b interface{}) (desc []string) {
  15. Pdiff((*sbuf)(&desc), a, b)
  16. return desc
  17. }
  // wprintfer adapts an io.Writer to the Printf method: every Printf call
  // becomes one Fprintf to w with a newline appended to the format.
  type wprintfer struct{ w io.Writer }

  // Printf writes the formatted message followed by "\n" to the wrapped
  // writer. Printf has no way to report failure, so the byte count and error
  // returned by Fprintf are discarded.
  func (wp *wprintfer) Printf(format string, a ...interface{}) {
  	line := format + "\n"
  	_, _ = fmt.Fprintf(wp.w, line, a...)
  }
  24. // Fdiff writes to w a description of the differences between a and b.
  25. func Fdiff(w io.Writer, a, b interface{}) {
  26. Pdiff(&wprintfer{w}, a, b)
  27. }
  // Printfer is implemented by any value that has a fmt-style Printf method.
  // Pdiff reports each difference through one.
  type Printfer interface {
  	Printf(format string, a ...interface{})
  }
  31. // Pdiff prints to p a description of the differences between a and b.
  32. // It calls Printf once for each difference, with no trailing newline.
  33. // The standard library log.Logger is a Printfer.
  34. func Pdiff(p Printfer, a, b interface{}) {
  35. diffPrinter{w: p}.diff(reflect.ValueOf(a), reflect.ValueOf(b))
  36. }
  // Logfer is implemented by any value that has a fmt-style Logf method.
  // Ldiff reports each difference through one.
  type Logfer interface {
  	Logf(format string, a ...interface{})
  }
  40. // logprintfer calls Fprintf on w for each Printf call
  41. // with a trailing newline.
  42. type logprintfer struct{ l Logfer }
  43. func (p *logprintfer) Printf(format string, a ...interface{}) {
  44. p.l.Logf(format, a...)
  45. }
  46. // Ldiff prints to l a description of the differences between a and b.
  47. // It calls Logf once for each difference, with no trailing newline.
  48. // The standard library testing.T and testing.B are Logfers.
  49. func Ldiff(l Logfer, a, b interface{}) {
  50. Pdiff(&logprintfer{l}, a, b)
  51. }
  52. type diffPrinter struct {
  53. w Printfer
  54. l string // label
  55. }
  56. func (w diffPrinter) printf(f string, a ...interface{}) {
  57. var l string
  58. if w.l != "" {
  59. l = w.l + ": "
  60. }
  61. w.w.Printf(l+f, a...)
  62. }
  63. func (w diffPrinter) diff(av, bv reflect.Value) {
  64. if !av.IsValid() && bv.IsValid() {
  65. w.printf("nil != %# v", formatter{v: bv, quote: true})
  66. return
  67. }
  68. if av.IsValid() && !bv.IsValid() {
  69. w.printf("%# v != nil", formatter{v: av, quote: true})
  70. return
  71. }
  72. if !av.IsValid() && !bv.IsValid() {
  73. return
  74. }
  75. at := av.Type()
  76. bt := bv.Type()
  77. if at != bt {
  78. w.printf("%v != %v", at, bt)
  79. return
  80. }
  81. switch kind := at.Kind(); kind {
  82. case reflect.Bool:
  83. if a, b := av.Bool(), bv.Bool(); a != b {
  84. w.printf("%v != %v", a, b)
  85. }
  86. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  87. if a, b := av.Int(), bv.Int(); a != b {
  88. w.printf("%d != %d", a, b)
  89. }
  90. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  91. if a, b := av.Uint(), bv.Uint(); a != b {
  92. w.printf("%d != %d", a, b)
  93. }
  94. case reflect.Float32, reflect.Float64:
  95. if a, b := av.Float(), bv.Float(); a != b {
  96. w.printf("%v != %v", a, b)
  97. }
  98. case reflect.Complex64, reflect.Complex128:
  99. if a, b := av.Complex(), bv.Complex(); a != b {
  100. w.printf("%v != %v", a, b)
  101. }
  102. case reflect.Array:
  103. n := av.Len()
  104. for i := 0; i < n; i++ {
  105. w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
  106. }
  107. case reflect.Chan, reflect.Func, reflect.UnsafePointer:
  108. if a, b := av.Pointer(), bv.Pointer(); a != b {
  109. w.printf("%#x != %#x", a, b)
  110. }
  111. case reflect.Interface:
  112. w.diff(av.Elem(), bv.Elem())
  113. case reflect.Map:
  114. ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys())
  115. for _, k := range ak {
  116. w := w.relabel(fmt.Sprintf("[%#v]", k))
  117. w.printf("%q != (missing)", av.MapIndex(k))
  118. }
  119. for _, k := range both {
  120. w := w.relabel(fmt.Sprintf("[%#v]", k))
  121. w.diff(av.MapIndex(k), bv.MapIndex(k))
  122. }
  123. for _, k := range bk {
  124. w := w.relabel(fmt.Sprintf("[%#v]", k))
  125. w.printf("(missing) != %q", bv.MapIndex(k))
  126. }
  127. case reflect.Ptr:
  128. switch {
  129. case av.IsNil() && !bv.IsNil():
  130. w.printf("nil != %# v", formatter{v: bv, quote: true})
  131. case !av.IsNil() && bv.IsNil():
  132. w.printf("%# v != nil", formatter{v: av, quote: true})
  133. case !av.IsNil() && !bv.IsNil():
  134. w.diff(av.Elem(), bv.Elem())
  135. }
  136. case reflect.Slice:
  137. lenA := av.Len()
  138. lenB := bv.Len()
  139. if lenA != lenB {
  140. w.printf("%s[%d] != %s[%d]", av.Type(), lenA, bv.Type(), lenB)
  141. break
  142. }
  143. for i := 0; i < lenA; i++ {
  144. w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
  145. }
  146. case reflect.String:
  147. if a, b := av.String(), bv.String(); a != b {
  148. w.printf("%q != %q", a, b)
  149. }
  150. case reflect.Struct:
  151. for i := 0; i < av.NumField(); i++ {
  152. w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i))
  153. }
  154. default:
  155. panic("unknown reflect Kind: " + kind.String())
  156. }
  157. }
  158. func (d diffPrinter) relabel(name string) (d1 diffPrinter) {
  159. d1 = d
  160. if d.l != "" && name[0] != '[' {
  161. d1.l += "."
  162. }
  163. d1.l += name
  164. return d1
  165. }
  // keyEqual reports whether av and bv hold equal map keys.
  //
  // Two invalid Values are equal. An invalid Value never equals a valid one,
  // and Values of different types are never equal. Channels and pointers are
  // compared by address, interfaces by the values they contain, and arrays
  // and structs element by element.
  //
  // keyEqual panics if the Values are of any other kind that is not handled
  // here, since such a value cannot be a map key.
  func keyEqual(av, bv reflect.Value) bool {
  	aOK, bOK := av.IsValid(), bv.IsValid()
  	if !aOK || !bOK {
  		return !aOK && !bOK
  	}
  	if av.Type() != bv.Type() {
  		return false
  	}

  	switch av.Kind() {
  	case reflect.Bool:
  		return av.Bool() == bv.Bool()
  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  		return av.Int() == bv.Int()
  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  		return av.Uint() == bv.Uint()
  	case reflect.Float32, reflect.Float64:
  		return av.Float() == bv.Float()
  	case reflect.Complex64, reflect.Complex128:
  		return av.Complex() == bv.Complex()
  	case reflect.String:
  		return av.String() == bv.String()
  	case reflect.Chan, reflect.UnsafePointer, reflect.Ptr:
  		return av.Pointer() == bv.Pointer()
  	case reflect.Interface:
  		return keyEqual(av.Elem(), bv.Elem())
  	case reflect.Array:
  		for i := 0; i < av.Len(); i++ {
  			if !keyEqual(av.Index(i), bv.Index(i)) {
  				return false
  			}
  		}
  		return true
  	case reflect.Struct:
  		for i := 0; i < av.NumField(); i++ {
  			if !keyEqual(av.Field(i), bv.Field(i)) {
  				return false
  			}
  		}
  		return true
  	}
  	panic("invalid map key type " + av.Type().String())
  }
  217. func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) {
  218. for _, av := range a {
  219. inBoth := false
  220. for _, bv := range b {
  221. if keyEqual(av, bv) {
  222. inBoth = true
  223. both = append(both, av)
  224. break
  225. }
  226. }
  227. if !inBoth {
  228. ak = append(ak, av)
  229. }
  230. }
  231. for _, bv := range b {
  232. inBoth := false
  233. for _, av := range a {
  234. if keyEqual(av, bv) {
  235. inBoth = true
  236. break
  237. }
  238. }
  239. if !inBoth {
  240. bk = append(bk, bv)
  241. }
  242. }
  243. return
  244. }