• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2023 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package proptools
16
17import (
18	"cmp"
19	"encoding/binary"
20	"fmt"
21	"hash"
22	"hash/fnv"
23	"math"
24	"reflect"
25	"slices"
26	"unsafe"
27)
28
// recordSeparator is a byte inserted between elements of lists, fields of
// structs/maps, etc., so that the hash changes when values move between
// elements. The value 36 is arbitrary (it is ASCII '$'; the actual ASCII
// "record separator" control code is 30, not 36). Do not change it:
// doing so would change every previously computed hash.
var recordSeparator []byte = []byte{36}
33
34func CalculateHash(value interface{}) (uint64, error) {
35	hasher := hasher{
36		Hash64:   fnv.New64(),
37		int64Buf: make([]byte, 8),
38	}
39	v := reflect.ValueOf(value)
40	var err error
41	if v.IsValid() {
42		err = hasher.calculateHash(v)
43	}
44	return hasher.Sum64(), err
45}
46
// hasher accumulates the hash of a reflect.Value tree. It embeds the
// running hash.Hash64 state and carries scratch storage so hashing large
// values does not allocate per element.
type hasher struct {
	hash.Hash64
	// int64Buf is an 8-byte scratch buffer used to encode integers before
	// feeding them to the underlying hash.
	int64Buf      []byte
	// ptrs records pointer addresses already visited, so that pointer
	// cycles terminate. Lazily allocated on the first pointer seen.
	ptrs          map[uintptr]bool
	// mapStateCache holds one reusable mapState so that hashing many maps
	// reuses the same backing slices (see getMapState/putMapState).
	mapStateCache *mapState
}
53
// mapState is scratch storage for hashing a single map: the map's keys and
// values, plus an index permutation that gets sorted by key to give a
// deterministic traversal order.
type mapState struct {
	indexes []int
	keys    []reflect.Value
	values  []reflect.Value
}
59
60func (hasher *hasher) writeUint64(i uint64) {
61	binary.LittleEndian.PutUint64(hasher.int64Buf, i)
62	hasher.Write(hasher.int64Buf)
63}
64
// writeInt feeds i into the hash as a uint64 (negative values contribute
// their two's-complement bit pattern).
func (hasher *hasher) writeInt(i int) {
	hasher.writeUint64(uint64(i))
}
68
69func (hasher *hasher) writeByte(i byte) {
70	hasher.int64Buf[0] = i
71	hasher.Write(hasher.int64Buf[:1])
72}
73
74func (hasher *hasher) getMapState(size int) *mapState {
75	s := hasher.mapStateCache
76	// Clear hasher.mapStateCache so that any recursive uses don't collide with this frame.
77	hasher.mapStateCache = nil
78
79	if s == nil {
80		s = &mapState{}
81	}
82
83	// Reset the slices to length `size` and capacity at least `size`
84	s.indexes = slices.Grow(s.indexes[:0], size)[0:size]
85	s.keys = slices.Grow(s.keys[:0], size)[0:size]
86	s.values = slices.Grow(s.values[:0], size)[0:size]
87
88	return s
89}
90
91func (hasher *hasher) putMapState(s *mapState) {
92	if hasher.mapStateCache == nil || cap(hasher.mapStateCache.indexes) < cap(s.indexes) {
93		hasher.mapStateCache = s
94	}
95}
96
97func (hasher *hasher) calculateHash(v reflect.Value) error {
98	hasher.writeUint64(uint64(v.Kind()))
99	v.IsValid()
100	switch v.Kind() {
101	case reflect.Struct:
102		l := v.NumField()
103		hasher.writeInt(l)
104		for i := 0; i < l; i++ {
105			hasher.Write(recordSeparator)
106			err := hasher.calculateHash(v.Field(i))
107			if err != nil {
108				return fmt.Errorf("in field %s: %s", v.Type().Field(i).Name, err.Error())
109			}
110		}
111	case reflect.Map:
112		l := v.Len()
113		hasher.writeInt(l)
114		iter := v.MapRange()
115		s := hasher.getMapState(l)
116		for i := 0; iter.Next(); i++ {
117			s.indexes[i] = i
118			s.keys[i] = iter.Key()
119			s.values[i] = iter.Value()
120		}
121		slices.SortFunc(s.indexes, func(i, j int) int {
122			return compare_values(s.keys[i], s.keys[j])
123		})
124		for i := 0; i < l; i++ {
125			hasher.Write(recordSeparator)
126			err := hasher.calculateHash(s.keys[s.indexes[i]])
127			if err != nil {
128				return fmt.Errorf("in map: %s", err.Error())
129			}
130			hasher.Write(recordSeparator)
131			err = hasher.calculateHash(s.keys[s.indexes[i]])
132			if err != nil {
133				return fmt.Errorf("in map: %s", err.Error())
134			}
135		}
136		hasher.putMapState(s)
137	case reflect.Slice, reflect.Array:
138		l := v.Len()
139		hasher.writeInt(l)
140		for i := 0; i < l; i++ {
141			hasher.Write(recordSeparator)
142			err := hasher.calculateHash(v.Index(i))
143			if err != nil {
144				return fmt.Errorf("in %s at index %d: %s", v.Kind().String(), i, err.Error())
145			}
146		}
147	case reflect.Pointer:
148		if v.IsNil() {
149			hasher.writeByte(0)
150			return nil
151		}
152		// Hardcoded value to indicate it is a pointer
153		hasher.writeInt(0x55)
154		addr := v.Pointer()
155		if hasher.ptrs == nil {
156			hasher.ptrs = make(map[uintptr]bool)
157		}
158		if _, ok := hasher.ptrs[addr]; ok {
159			// We could make this an error if we want to disallow pointer cycles in the future
160			return nil
161		}
162		hasher.ptrs[addr] = true
163		err := hasher.calculateHash(v.Elem())
164		if err != nil {
165			return fmt.Errorf("in pointer: %s", err.Error())
166		}
167	case reflect.Interface:
168		if v.IsNil() {
169			hasher.writeByte(0)
170		} else {
171			// The only way get the pointer out of an interface to hash it or check for cycles
172			// would be InterfaceData(), but that's deprecated and seems like it has undefined behavior.
173			err := hasher.calculateHash(v.Elem())
174			if err != nil {
175				return fmt.Errorf("in interface: %s", err.Error())
176			}
177		}
178	case reflect.String:
179		strLen := len(v.String())
180		if strLen == 0 {
181			// unsafe.StringData is unspecified in this case
182			hasher.writeByte(0)
183			return nil
184		}
185		hasher.Write(unsafe.Slice(unsafe.StringData(v.String()), strLen))
186	case reflect.Bool:
187		if v.Bool() {
188			hasher.writeByte(1)
189		} else {
190			hasher.writeByte(0)
191		}
192	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
193		hasher.writeUint64(v.Uint())
194	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
195		hasher.writeUint64(uint64(v.Int()))
196	case reflect.Float32, reflect.Float64:
197		hasher.writeUint64(math.Float64bits(v.Float()))
198	default:
199		return fmt.Errorf("data may only contain primitives, strings, arrays, slices, structs, maps, and pointers, found: %s", v.Kind().String())
200	}
201	return nil
202}
203
204func compare_values(x, y reflect.Value) int {
205	if x.Type() != y.Type() {
206		panic("Expected equal types")
207	}
208
209	switch x.Kind() {
210	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
211		return cmp.Compare(x.Uint(), y.Uint())
212	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
213		return cmp.Compare(x.Int(), y.Int())
214	case reflect.Float32, reflect.Float64:
215		return cmp.Compare(x.Float(), y.Float())
216	case reflect.String:
217		return cmp.Compare(x.String(), y.String())
218	case reflect.Bool:
219		if x.Bool() == y.Bool() {
220			return 0
221		} else if x.Bool() {
222			return 1
223		} else {
224			return -1
225		}
226	case reflect.Pointer:
227		return cmp.Compare(x.Pointer(), y.Pointer())
228	case reflect.Array:
229		l := x.Len()
230		for i := 0; i < l; i++ {
231			if result := compare_values(x.Index(i), y.Index(i)); result != 0 {
232				return result
233			}
234		}
235		return 0
236	case reflect.Struct:
237		l := x.NumField()
238		for i := 0; i < l; i++ {
239			if result := compare_values(x.Field(i), y.Field(i)); result != 0 {
240				return result
241			}
242		}
243		return 0
244	case reflect.Interface:
245		if x.IsNil() && y.IsNil() {
246			return 0
247		} else if x.IsNil() {
248			return 1
249		} else if y.IsNil() {
250			return -1
251		}
252		return compare_values(x.Elem(), y.Elem())
253	default:
254		panic(fmt.Sprintf("Could not compare types %s and %s", x.Type().String(), y.Type().String()))
255	}
256}
257
258func ContainsConfigurable(value interface{}) bool {
259	ptrs := make(map[uintptr]bool)
260	v := reflect.ValueOf(value)
261	if v.IsValid() {
262		return containsConfigurableInternal(v, ptrs)
263	}
264	return false
265}
266
267func containsConfigurableInternal(v reflect.Value, ptrs map[uintptr]bool) bool {
268	switch v.Kind() {
269	case reflect.Struct:
270		t := v.Type()
271		if IsConfigurable(t) {
272			return true
273		}
274		typeFields := typeFields(t)
275		for i := 0; i < v.NumField(); i++ {
276			if HasTag(typeFields[i], "blueprint", "allow_configurable_in_provider") {
277				continue
278			}
279			if containsConfigurableInternal(v.Field(i), ptrs) {
280				return true
281			}
282		}
283	case reflect.Map:
284		iter := v.MapRange()
285		for iter.Next() {
286			key := iter.Key()
287			value := iter.Value()
288			if containsConfigurableInternal(key, ptrs) {
289				return true
290			}
291			if containsConfigurableInternal(value, ptrs) {
292				return true
293			}
294		}
295	case reflect.Slice, reflect.Array:
296		l := v.Len()
297		for i := 0; i < l; i++ {
298			if containsConfigurableInternal(v.Index(i), ptrs) {
299				return true
300			}
301		}
302	case reflect.Pointer:
303		if v.IsNil() {
304			return false
305		}
306		addr := v.Pointer()
307		if _, ok := ptrs[addr]; ok {
308			// pointer cycle
309			return false
310		}
311		ptrs[addr] = true
312		if containsConfigurableInternal(v.Elem(), ptrs) {
313			return true
314		}
315	case reflect.Interface:
316		if v.IsNil() {
317			return false
318		} else {
319			// The only way get the pointer out of an interface to hash it or check for cycles
320			// would be InterfaceData(), but that's deprecated and seems like it has undefined behavior.
321			if containsConfigurableInternal(v.Elem(), ptrs) {
322				return true
323			}
324		}
325	default:
326		return false
327	}
328	return false
329}
330