diff.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. package pretty
  2. import (
  3. "fmt"
  4. "io"
  5. "reflect"
  6. )
  7. type sbuf []string
  8. func (p *sbuf) Printf(format string, a ...interface{}) {
  9. s := fmt.Sprintf(format, a...)
  10. *p = append(*p, s)
  11. }
  12. // Diff returns a slice where each element describes
  13. // a difference between a and b.
  14. func Diff(a, b interface{}) (desc []string) {
  15. Pdiff((*sbuf)(&desc), a, b)
  16. return desc
  17. }
  18. // wprintfer calls Fprintf on w for each Printf call
  19. // with a trailing newline.
  20. type wprintfer struct{ w io.Writer }
  21. func (p *wprintfer) Printf(format string, a ...interface{}) {
  22. fmt.Fprintf(p.w, format+"\n", a...)
  23. }
  24. // Fdiff writes to w a description of the differences between a and b.
  25. func Fdiff(w io.Writer, a, b interface{}) {
  26. Pdiff(&wprintfer{w}, a, b)
  27. }
  28. type Printfer interface {
  29. Printf(format string, a ...interface{})
  30. }
  31. // Pdiff prints to p a description of the differences between a and b.
  32. // It calls Printf once for each difference, with no trailing newline.
  33. // The standard library log.Logger is a Printfer.
  34. func Pdiff(p Printfer, a, b interface{}) {
  35. d := diffPrinter{
  36. w: p,
  37. aVisited: make(map[visit]visit),
  38. bVisited: make(map[visit]visit),
  39. }
  40. d.diff(reflect.ValueOf(a), reflect.ValueOf(b))
  41. }
  42. type Logfer interface {
  43. Logf(format string, a ...interface{})
  44. }
  45. // logprintfer calls Fprintf on w for each Printf call
  46. // with a trailing newline.
  47. type logprintfer struct{ l Logfer }
  48. func (p *logprintfer) Printf(format string, a ...interface{}) {
  49. p.l.Logf(format, a...)
  50. }
  51. // Ldiff prints to l a description of the differences between a and b.
  52. // It calls Logf once for each difference, with no trailing newline.
  53. // The standard library testing.T and testing.B are Logfers.
  54. func Ldiff(l Logfer, a, b interface{}) {
  55. Pdiff(&logprintfer{l}, a, b)
  56. }
  57. type diffPrinter struct {
  58. w Printfer
  59. l string // label
  60. aVisited map[visit]visit
  61. bVisited map[visit]visit
  62. }
  63. func (w diffPrinter) printf(f string, a ...interface{}) {
  64. var l string
  65. if w.l != "" {
  66. l = w.l + ": "
  67. }
  68. w.w.Printf(l+f, a...)
  69. }
  70. func (w diffPrinter) diff(av, bv reflect.Value) {
  71. if !av.IsValid() && bv.IsValid() {
  72. w.printf("nil != %# v", formatter{v: bv, quote: true})
  73. return
  74. }
  75. if av.IsValid() && !bv.IsValid() {
  76. w.printf("%# v != nil", formatter{v: av, quote: true})
  77. return
  78. }
  79. if !av.IsValid() && !bv.IsValid() {
  80. return
  81. }
  82. at := av.Type()
  83. bt := bv.Type()
  84. if at != bt {
  85. w.printf("%v != %v", at, bt)
  86. return
  87. }
  88. if av.CanAddr() && bv.CanAddr() {
  89. avis := visit{av.UnsafeAddr(), at}
  90. bvis := visit{bv.UnsafeAddr(), bt}
  91. var cycle bool
  92. // Have we seen this value before?
  93. if vis, ok := w.aVisited[avis]; ok {
  94. cycle = true
  95. if vis != bvis {
  96. w.printf("%# v (previously visited) != %# v", formatter{v: av, quote: true}, formatter{v: bv, quote: true})
  97. }
  98. } else if _, ok := w.bVisited[bvis]; ok {
  99. cycle = true
  100. w.printf("%# v != %# v (previously visited)", formatter{v: av, quote: true}, formatter{v: bv, quote: true})
  101. }
  102. w.aVisited[avis] = bvis
  103. w.bVisited[bvis] = avis
  104. if cycle {
  105. return
  106. }
  107. }
  108. switch kind := at.Kind(); kind {
  109. case reflect.Bool:
  110. if a, b := av.Bool(), bv.Bool(); a != b {
  111. w.printf("%v != %v", a, b)
  112. }
  113. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  114. if a, b := av.Int(), bv.Int(); a != b {
  115. w.printf("%d != %d", a, b)
  116. }
  117. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  118. if a, b := av.Uint(), bv.Uint(); a != b {
  119. w.printf("%d != %d", a, b)
  120. }
  121. case reflect.Float32, reflect.Float64:
  122. if a, b := av.Float(), bv.Float(); a != b {
  123. w.printf("%v != %v", a, b)
  124. }
  125. case reflect.Complex64, reflect.Complex128:
  126. if a, b := av.Complex(), bv.Complex(); a != b {
  127. w.printf("%v != %v", a, b)
  128. }
  129. case reflect.Array:
  130. n := av.Len()
  131. for i := 0; i < n; i++ {
  132. w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
  133. }
  134. case reflect.Chan, reflect.Func, reflect.UnsafePointer:
  135. if a, b := av.Pointer(), bv.Pointer(); a != b {
  136. w.printf("%#x != %#x", a, b)
  137. }
  138. case reflect.Interface:
  139. w.diff(av.Elem(), bv.Elem())
  140. case reflect.Map:
  141. ak, both, bk := keyDiff(av.MapKeys(), bv.MapKeys())
  142. for _, k := range ak {
  143. w := w.relabel(fmt.Sprintf("[%#v]", k))
  144. w.printf("%q != (missing)", av.MapIndex(k))
  145. }
  146. for _, k := range both {
  147. w := w.relabel(fmt.Sprintf("[%#v]", k))
  148. w.diff(av.MapIndex(k), bv.MapIndex(k))
  149. }
  150. for _, k := range bk {
  151. w := w.relabel(fmt.Sprintf("[%#v]", k))
  152. w.printf("(missing) != %q", bv.MapIndex(k))
  153. }
  154. case reflect.Ptr:
  155. switch {
  156. case av.IsNil() && !bv.IsNil():
  157. w.printf("nil != %# v", formatter{v: bv, quote: true})
  158. case !av.IsNil() && bv.IsNil():
  159. w.printf("%# v != nil", formatter{v: av, quote: true})
  160. case !av.IsNil() && !bv.IsNil():
  161. w.diff(av.Elem(), bv.Elem())
  162. }
  163. case reflect.Slice:
  164. lenA := av.Len()
  165. lenB := bv.Len()
  166. if lenA != lenB {
  167. w.printf("%s[%d] != %s[%d]", av.Type(), lenA, bv.Type(), lenB)
  168. break
  169. }
  170. for i := 0; i < lenA; i++ {
  171. w.relabel(fmt.Sprintf("[%d]", i)).diff(av.Index(i), bv.Index(i))
  172. }
  173. case reflect.String:
  174. if a, b := av.String(), bv.String(); a != b {
  175. w.printf("%q != %q", a, b)
  176. }
  177. case reflect.Struct:
  178. for i := 0; i < av.NumField(); i++ {
  179. w.relabel(at.Field(i).Name).diff(av.Field(i), bv.Field(i))
  180. }
  181. default:
  182. panic("unknown reflect Kind: " + kind.String())
  183. }
  184. }
  185. func (d diffPrinter) relabel(name string) (d1 diffPrinter) {
  186. d1 = d
  187. if d.l != "" && name[0] != '[' {
  188. d1.l += "."
  189. }
  190. d1.l += name
  191. return d1
  192. }
  193. // keyEqual compares a and b for equality.
  194. // Both a and b must be valid map keys.
  195. func keyEqual(av, bv reflect.Value) bool {
  196. if !av.IsValid() && !bv.IsValid() {
  197. return true
  198. }
  199. if !av.IsValid() || !bv.IsValid() || av.Type() != bv.Type() {
  200. return false
  201. }
  202. switch kind := av.Kind(); kind {
  203. case reflect.Bool:
  204. a, b := av.Bool(), bv.Bool()
  205. return a == b
  206. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  207. a, b := av.Int(), bv.Int()
  208. return a == b
  209. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  210. a, b := av.Uint(), bv.Uint()
  211. return a == b
  212. case reflect.Float32, reflect.Float64:
  213. a, b := av.Float(), bv.Float()
  214. return a == b
  215. case reflect.Complex64, reflect.Complex128:
  216. a, b := av.Complex(), bv.Complex()
  217. return a == b
  218. case reflect.Array:
  219. for i := 0; i < av.Len(); i++ {
  220. if !keyEqual(av.Index(i), bv.Index(i)) {
  221. return false
  222. }
  223. }
  224. return true
  225. case reflect.Chan, reflect.UnsafePointer, reflect.Ptr:
  226. a, b := av.Pointer(), bv.Pointer()
  227. return a == b
  228. case reflect.Interface:
  229. return keyEqual(av.Elem(), bv.Elem())
  230. case reflect.String:
  231. a, b := av.String(), bv.String()
  232. return a == b
  233. case reflect.Struct:
  234. for i := 0; i < av.NumField(); i++ {
  235. if !keyEqual(av.Field(i), bv.Field(i)) {
  236. return false
  237. }
  238. }
  239. return true
  240. default:
  241. panic("invalid map key type " + av.Type().String())
  242. }
  243. }
  244. func keyDiff(a, b []reflect.Value) (ak, both, bk []reflect.Value) {
  245. for _, av := range a {
  246. inBoth := false
  247. for _, bv := range b {
  248. if keyEqual(av, bv) {
  249. inBoth = true
  250. both = append(both, av)
  251. break
  252. }
  253. }
  254. if !inBoth {
  255. ak = append(ak, av)
  256. }
  257. }
  258. for _, bv := range b {
  259. inBoth := false
  260. for _, av := range a {
  261. if keyEqual(av, bv) {
  262. inBoth = true
  263. break
  264. }
  265. }
  266. if !inBoth {
  267. bk = append(bk, bv)
  268. }
  269. }
  270. return
  271. }