• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1'use strict';
2const {
3  ArrayPrototypePush,
4  ArrayPrototypeSlice,
5  Error,
6  FunctionPrototypeCall,
7  ObjectDefineProperty,
8  ObjectGetOwnPropertyDescriptor,
9  ObjectGetPrototypeOf,
10  Proxy,
11  ReflectApply,
12  ReflectConstruct,
13  ReflectGet,
14  SafeMap,
15} = primordials;
16const {
17  codes: {
18    ERR_INVALID_ARG_TYPE,
19    ERR_INVALID_ARG_VALUE,
20  },
21} = require('internal/errors');
22const { kEmptyObject } = require('internal/util');
23const {
24  validateBoolean,
25  validateFunction,
26  validateInteger,
27  validateObject,
28} = require('internal/validators');
29
/**
 * Sentinel default for the `implementation` parameter of the method, getter
 * and setter mocking APIs. Only its identity matters: when a caller leaves
 * `implementation` as this function, the original method is used instead.
 */
function kDefaultFunction() {
  return undefined;
}
31
/**
 * State and controls for a single mock, exposed to users as `mockFn.mock`.
 *
 * It records every call routed through the mock, holds the active
 * implementation plus one-shot implementations keyed by call index, and can
 * undo the mock via `restore()`.
 */
class MockFunctionContext {
  #calls;
  #mocks;
  #implementation;
  #restore;
  #times;

  /**
   * @param {Function} implementation Implementation the mock delegates to.
   * @param {object} restore Data needed by `restore()`: `{ original }` for a
   *   bare function mock, or `{ descriptor, object, methodName }` for a
   *   method/accessor mock.
   * @param {number} times Call count after which the mock restores itself;
   *   `Infinity` means it never does.
   */
  constructor(implementation, restore, times) {
    this.#calls = [];
    this.#mocks = new SafeMap();
    this.#implementation = implementation;
    this.#restore = restore;
    this.#times = times;
  }

  /**
   * A shallow copy of the recorded call objects, so callers cannot mutate
   * the internal list.
   */
  get calls() {
    return ArrayPrototypeSlice(this.#calls, 0);
  }

  /**
   * @returns {number} The number of calls recorded so far.
   */
  callCount() {
    return this.#calls.length;
  }

  /**
   * Replace the implementation used for all subsequent calls.
   * @param {Function} implementation
   */
  mockImplementation(implementation) {
    validateFunction(implementation, 'implementation');
    this.#implementation = implementation;
  }

  /**
   * Use `implementation` for exactly one call.
   * @param {Function} implementation
   * @param {number} [onCall] Zero-based index of the call to override.
   *   Defaults to the next call and may not refer to a call that already
   *   happened.
   */
  mockImplementationOnce(implementation, onCall) {
    validateFunction(implementation, 'implementation');
    const nextCall = this.#calls.length;
    const call = onCall ?? nextCall;
    validateInteger(call, 'onCall', nextCall);
    this.#mocks.set(call, implementation);
  }

  /**
   * Undo the mock. For a method/accessor mock, the original property
   * descriptor is redefined on the mocked object. A bare function mock
   * cannot be removed, so it is made to call the original function instead.
   */
  restore() {
    const { descriptor, object, original, methodName } = this.#restore;

    // Only method/accessor mocks carry a method name. Property keys may be
    // strings or symbols, so test for presence rather than for `string`.
    // A `typeof === 'string'` check sends symbol-keyed mocks down the bare
    // function path and leaves the mocked property in place.
    if (methodName !== undefined) {
      ObjectDefineProperty(object, methodName, descriptor);
    } else {
      this.#implementation = original;
    }
  }

  /**
   * Forget all recorded calls.
   */
  resetCalls() {
    this.#calls = [];
  }

  // Internal: record one call. Removed from the prototype below, so it is
  // only reachable through the captured module-level reference.
  trackCall(call) {
    ArrayPrototypePush(this.#calls, call);
  }

  // Internal: pick the implementation for the upcoming call. A one-shot
  // implementation registered for this call index wins; otherwise the
  // current one is used. If this is the `times`-th call, the mock is
  // restored first, but the implementation chosen here is still used for
  // this call.
  nextImpl() {
    const nextCall = this.#calls.length;
    const mock = this.#mocks.get(nextCall);
    const impl = mock ?? this.#implementation;

    if (nextCall + 1 === this.#times) {
      this.restore();
    }

    this.#mocks.delete(nextCall);
    return impl;
  }
}

// Capture the internal helpers, then delete them from the prototype so user
// code holding `mockFn.mock` cannot call them directly. `restore` stays on
// the prototype because it is public API; it is captured here as well.
const { nextImpl, restore, trackCall } = MockFunctionContext.prototype;
delete MockFunctionContext.prototype.trackCall;
delete MockFunctionContext.prototype.nextImpl;
106
/**
 * Entry point for creating mocks in the test runner.
 *
 * Every mock created through a tracker is remembered, so that all of them
 * can be undone in bulk with `restoreAll()` or `reset()`.
 */
class MockTracker {
  #mocks = [];

  /**
   * Create a mock function.
   *
   * Accepted call shapes: `fn()`, `fn(options)`, `fn(original)`,
   * `fn(original, options)` and `fn(original, implementation, options)`.
   *
   * @param {Function} [original] Function wrapped by the mock; an empty
   *   function when omitted.
   * @param {Function} [implementation] Function the mock delegates to;
   *   defaults to `original`.
   * @param {object} [options]
   * @param {number} [options.times=Infinity] After this many calls the mock
   *   starts delegating to `original`.
   * @returns {Function} A Proxy around `original` that records its calls.
   */
  fn(
    original = function() {},
    implementation = original,
    options = kEmptyObject,
  ) {
    if (original !== null && typeof original === 'object') {
      // fn(options): no function was supplied at all.
      options = original;
      original = function() {};
      implementation = original;
    } else if (implementation !== null && typeof implementation === 'object') {
      // fn(original, options): the mock keeps calling `original`.
      options = implementation;
      implementation = original;
    }

    validateFunction(original, 'original');
    validateFunction(implementation, 'implementation');
    validateObject(options, 'options');

    const { times = Infinity } = options;
    validateTimes(times, 'options.times');

    return this.#setupMock(
      new MockFunctionContext(implementation, { original }, times),
      original,
    );
  }

  /**
   * Replace a method on `objectOrFunction` with a mock. With
   * `options.getter` or `options.setter`, only that half of an accessor is
   * replaced. The property may live anywhere on the prototype chain; the
   * mock is defined as an own property of `objectOrFunction`.
   *
   * @param {object|Function} objectOrFunction Target holding the method.
   * @param {string|symbol} methodName Property key of the method.
   * @param {Function|object} [implementation] Replacement behavior. When
   *   omitted, the mock calls through to the original method. May also be
   *   the options bag.
   * @param {object} [options]
   * @param {boolean} [options.getter=false] Mock the property's getter.
   * @param {boolean} [options.setter=false] Mock the property's setter.
   * @param {number} [options.times=Infinity] After this many calls the
   *   original property descriptor is put back.
   * @returns {Function} The mock.
   * @throws {ERR_INVALID_ARG_VALUE} If both getter and setter are requested,
   *   or the selected slot does not hold a function.
   */
  method(
    objectOrFunction,
    methodName,
    implementation = kDefaultFunction,
    options = kEmptyObject,
  ) {
    validateStringOrSymbol(methodName, 'methodName');
    if (typeof objectOrFunction !== 'function') {
      validateObject(objectOrFunction, 'object');
    }

    // method(target, name, options): the options bag took the third slot.
    if (implementation !== null && typeof implementation === 'object') {
      options = implementation;
      implementation = kDefaultFunction;
    }

    validateFunction(implementation, 'implementation');
    validateObject(options, 'options');

    const { getter = false, setter = false, times = Infinity } = options;
    validateBoolean(getter, 'options.getter');
    validateBoolean(setter, 'options.setter');
    validateTimes(times, 'options.times');

    if (getter && setter) {
      throw new ERR_INVALID_ARG_VALUE(
        'options.setter', setter, "cannot be used with 'options.getter'",
      );
    }

    // Which property descriptor slot is being mocked.
    let slot = 'value';
    if (getter) {
      slot = 'get';
    } else if (setter) {
      slot = 'set';
    }

    const descriptor = findMethodOnPrototypeChain(objectOrFunction, methodName);
    const original = descriptor?.[slot];

    if (typeof original !== 'function') {
      throw new ERR_INVALID_ARG_VALUE(
        'methodName', original, 'must be a method',
      );
    }

    const ctx = new MockFunctionContext(
      implementation === kDefaultFunction ? original : implementation,
      { descriptor, object: objectOrFunction, methodName },
      times,
    );
    const mock = this.#setupMock(ctx, original);

    // Keep the original attributes; swap the mock into the mocked slot only.
    const mockDescriptor = {
      __proto__: null,
      configurable: descriptor.configurable,
      enumerable: descriptor.enumerable,
    };

    if (slot === 'value') {
      mockDescriptor.writable = descriptor.writable;
      mockDescriptor.value = mock;
    } else {
      mockDescriptor.get = slot === 'get' ? mock : descriptor.get;
      mockDescriptor.set = slot === 'set' ? mock : descriptor.set;
    }

    ObjectDefineProperty(objectOrFunction, methodName, mockDescriptor);
    return mock;
  }

  /**
   * Shorthand for `method()` with `options.getter` forced on.
   * @throws {ERR_INVALID_ARG_VALUE} If `options.getter` is explicitly false.
   */
  getter(
    object,
    methodName,
    implementation = kDefaultFunction,
    options = kEmptyObject,
  ) {
    return MockTracker.#mockAccessor(
      this, 'getter', object, methodName, implementation, options,
    );
  }

  /**
   * Shorthand for `method()` with `options.setter` forced on.
   * @throws {ERR_INVALID_ARG_VALUE} If `options.setter` is explicitly false.
   */
  setter(
    object,
    methodName,
    implementation = kDefaultFunction,
    options = kEmptyObject,
  ) {
    return MockTracker.#mockAccessor(
      this, 'setter', object, methodName, implementation, options,
    );
  }

  // Shared body of getter()/setter(). `kind` is 'getter' or 'setter', and
  // names both the option that is forced on and the option reported in the
  // error.
  static #mockAccessor(
    tracker,
    kind,
    object,
    methodName,
    implementation,
    options,
  ) {
    if (implementation !== null && typeof implementation === 'object') {
      options = implementation;
      implementation = kDefaultFunction;
    } else {
      validateObject(options, 'options');
    }

    const { [kind]: enabled = true } = options;

    if (enabled === false) {
      throw new ERR_INVALID_ARG_VALUE(
        `options.${kind}`, enabled, 'cannot be false',
      );
    }

    return tracker.method(object, methodName, implementation, {
      ...options,
      [kind]: enabled,
    });
  }

  /**
   * Restore every mock and stop tracking them.
   */
  reset() {
    this.restoreAll();
    this.#mocks = [];
  }

  /**
   * Restore every mock created by this tracker; they remain tracked.
   */
  restoreAll() {
    for (let index = 0; index < this.#mocks.length; index++) {
      const ctx = this.#mocks[index];
      FunctionPrototypeCall(restore, ctx);
    }
  }

  /**
   * Wrap `fnToMatch` in a Proxy that routes calls and `new` invocations
   * through `ctx`, records each one, and exposes `ctx` as the `mock`
   * property. The context is registered with this tracker.
   */
  #setupMock(ctx, fnToMatch) {
    const handler = {
      __proto__: null,
      apply(_target, thisArg, argList) {
        const impl = FunctionPrototypeCall(nextImpl, ctx);
        let result;
        let error;

        try {
          result = ReflectApply(impl, thisArg, argList);
          return result;
        } catch (err) {
          error = err;
          throw err;
        } finally {
          // The call is recorded whether it returned or threw.
          FunctionPrototypeCall(trackCall, ctx, {
            arguments: argList,
            error,
            result,
            // eslint-disable-next-line no-restricted-syntax
            stack: new Error(),
            target: undefined,
            this: thisArg,
          });
        }
      },
      construct(target, argList, newTarget) {
        const impl = FunctionPrototypeCall(nextImpl, ctx);
        let result;
        let error;

        try {
          result = ReflectConstruct(impl, argList, newTarget);
          return result;
        } catch (err) {
          error = err;
          throw err;
        } finally {
          FunctionPrototypeCall(trackCall, ctx, {
            arguments: argList,
            error,
            result,
            // eslint-disable-next-line no-restricted-syntax
            stack: new Error(),
            target,
            // For `new`, the receiver is the constructed object.
            this: result,
          });
        }
      },
      get(target, property, receiver) {
        // `mock` is reserved for the context; everything else is forwarded.
        return property === 'mock' ?
          ctx :
          ReflectGet(target, property, receiver);
      },
    };

    const mock = new Proxy(fnToMatch, handler);
    ArrayPrototypePush(this.#mocks, ctx);
    return mock;
  }
}
341
/**
 * Throw ERR_INVALID_ARG_TYPE unless `value` is usable as a property key
 * (a string or a symbol).
 * @param {*} value Value to check.
 * @param {string} name Argument name reported in the error.
 */
function validateStringOrSymbol(value, name) {
  const valueType = typeof value;
  if (valueType === 'string' || valueType === 'symbol') {
    return;
  }
  throw new ERR_INVALID_ARG_TYPE(name, ['string', 'symbol'], value);
}
347
/**
 * Validate a `times` option: either `Infinity` (never auto-restore) or an
 * integer of at least 1.
 * @param {*} value Value to check.
 * @param {string} name Argument name reported in any error.
 */
function validateTimes(value, name) {
  if (value !== Infinity) {
    validateInteger(value, name, 1);
  }
}
355
/**
 * Walk `instance` and then its prototype chain, returning the first own
 * property descriptor found for `methodName`.
 * @param {object|Function} instance Object to start the search from.
 * @param {string|symbol} methodName Property key to look up.
 * @returns {object|undefined} The descriptor, or `undefined` if no object
 *   in the chain defines the key.
 */
function findMethodOnPrototypeChain(instance, methodName) {
  for (let host = instance; host !== null; host = ObjectGetPrototypeOf(host)) {
    const descriptor = ObjectGetOwnPropertyDescriptor(host, methodName);
    if (descriptor !== undefined) {
      return descriptor;
    }
  }
  return undefined;
}
372
373module.exports = { MockTracker };
374