Skip to content

Commit 41c2254

Browse files
committed
chore: Refactor mocking completions in tests
1 parent 182fbbf commit 41c2254

File tree

5 files changed

+91
-103
lines changed

5 files changed

+91
-103
lines changed

packages/navie/test/fixture.ts

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,3 @@ export function predictsSummary(): (messages: Message[]) => void {
9898
return Promise.resolve([]);
9999
});
100100
}
101-
102-
export function mockAIResponse(completionWithRetry: jest.Mock, responses: string[]): void {
103-
completionWithRetry.mockResolvedValueOnce(
104-
responses.map((response, index) => ({
105-
id: 'cmpl-3Z5z9J5Z5Z5Z5Z5Z5Z5Z5Z5Z5Z5',
106-
choices: [
107-
{
108-
delta: {
109-
content: response,
110-
},
111-
index,
112-
finish_reason: index === responses.length - 1 ? 'stop' : null,
113-
},
114-
],
115-
created: 1635989729,
116-
model: 'gpt-3.5',
117-
object: 'chat.completion.chunk',
118-
}))
119-
);
120-
}
Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,28 @@
1-
import { ChatOpenAI } from '@langchain/openai';
2-
31
import InteractionHistory from '../../src/interaction-history';
42
import ClassificationService from '../../src/services/classification-service';
5-
import { mockAIResponse } from '../fixture';
6-
import OpenAICompletionService from '../../src/services/openai-completion-service';
7-
import Trajectory from '../../src/lib/trajectory';
8-
import { TrajectoryEvent } from '../../dist/lib/trajectory';
9-
import MessageTokenReducerService from '../../src/services/message-token-reducer-service';
103

11-
jest.mock('@langchain/openai');
12-
const completionWithRetry = jest.mocked(ChatOpenAI.prototype.completionWithRetry);
4+
import MockCompletionService from './mock-completion-service';
135

146
describe('ClassificationService', () => {
157
let interactionHistory: InteractionHistory;
16-
let trajectory: Trajectory;
178
let service: ClassificationService;
9+
const completion = new MockCompletionService();
10+
const completeSpy = jest.spyOn(completion, 'complete');
1811

1912
beforeEach(() => {
2013
interactionHistory = new InteractionHistory();
2114
interactionHistory.on('event', (event) => console.log(event.message));
22-
trajectory = new Trajectory();
23-
service = new ClassificationService(
24-
interactionHistory,
25-
new OpenAICompletionService('gpt-4', 0.5, trajectory, new MessageTokenReducerService())
26-
);
15+
service = new ClassificationService(interactionHistory, completion);
2716
});
28-
afterEach(() => jest.resetAllMocks());
17+
afterEach(() => jest.restoreAllMocks());
2918

3019
describe('when LLM responds', () => {
3120
const classification = `
3221
- architecture: high
3322
- troubleshoot: medium
3423
`;
3524

36-
beforeEach(() => mockAIResponse(completionWithRetry, [classification]));
25+
beforeEach(() => completion.mock(classification));
3726

3827
it('returns the response', async () => {
3928
const response = await service.classifyQuestion('user management');
@@ -47,7 +36,7 @@ describe('ClassificationService', () => {
4736
weight: 'medium',
4837
},
4938
]);
50-
expect(completionWithRetry).toHaveBeenCalledTimes(1);
39+
expect(completeSpy).toHaveBeenCalledTimes(1);
5140
});
5241

5342
it('emits classification event', async () => {
@@ -60,16 +49,5 @@ describe('ClassificationService', () => {
6049
classification: ['architecture=high', 'troubleshoot=medium'],
6150
});
6251
});
63-
64-
it('emits trajectory events', async () => {
65-
const trajectoryEvents = new Array<TrajectoryEvent>();
66-
67-
trajectory.on('event', (event) => trajectoryEvents.push(event));
68-
69-
await service.classifyQuestion('user management');
70-
71-
expect(trajectoryEvents.map((e) => e.message.role)).toEqual(['system', 'user', 'assistant']);
72-
expect(trajectoryEvents.map((e) => e.type)).toEqual(['sent', 'sent', 'received']);
73-
});
7452
});
7553
});
Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,20 @@
1-
import { ChatOpenAI } from '@langchain/openai';
2-
31
import InteractionHistory from '../../src/interaction-history';
4-
import ClassificationService from '../../src/services/classification-service';
5-
import { mockAIResponse } from '../fixture';
6-
import OpenAICompletionService from '../../src/services/openai-completion-service';
72
import ComputeUpdateService from '../../src/services/compute-update-service';
8-
import Trajectory from '../../src/lib/trajectory';
9-
import MessageTokenReducerService from '../../src/services/message-token-reducer-service';
103

11-
jest.mock('@langchain/openai');
12-
const completionWithRetry = jest.mocked(ChatOpenAI.prototype.completionWithRetry);
4+
import MockCompletionService from './mock-completion-service';
135

146
describe('ComputeUpdateService', () => {
157
let interactionHistory: InteractionHistory;
16-
let trajectory: Trajectory;
178
let service: ComputeUpdateService;
9+
const completion = new MockCompletionService();
10+
const complete = jest.spyOn(completion, 'complete');
1811

1912
beforeEach(() => {
2013
interactionHistory = new InteractionHistory();
2114
interactionHistory.on('event', (event) => console.log(event.message));
22-
trajectory = new Trajectory();
23-
service = new ComputeUpdateService(
24-
interactionHistory,
25-
new OpenAICompletionService('gpt-4', 0.5, trajectory, new MessageTokenReducerService())
26-
);
15+
service = new ComputeUpdateService(interactionHistory, completion);
2716
});
28-
afterEach(() => jest.resetAllMocks());
17+
afterEach(() => jest.restoreAllMocks());
2918

3019
describe('when LLM responds', () => {
3120
const existingContent = `class User < ApplicationRecord
@@ -50,12 +39,12 @@ end
5039
</change>
5140
`;
5241

53-
beforeEach(() => mockAIResponse(completionWithRetry, [changeStr]));
42+
beforeEach(() => completion.mock(changeStr));
5443

5544
it('computes the update', async () => {
5645
const response = await service.computeUpdate(existingContent, newContent);
5746
expect(response).toStrictEqual(change);
58-
expect(completionWithRetry).toHaveBeenCalledTimes(1);
47+
expect(complete).toHaveBeenCalledTimes(1);
5948
});
6049
});
6150
});
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import assert from 'node:assert';
2+
3+
import { ZodType } from 'zod';
4+
5+
import type { Message } from '../../src';
6+
import CompletionService, { type Completion, Usage } from '../../src/services/completion-service';
7+
8+
export default class MockCompletionService implements CompletionService {
9+
// eslint-disable-next-line @typescript-eslint/require-await
10+
async *complete(messages: readonly Message[]): Completion {
11+
const completion = this.completion(messages);
12+
for (const c of completion) {
13+
yield c;
14+
}
15+
return new Usage();
16+
}
17+
18+
/**
19+
* The mock completion function. This function can be used to mock the completion result.
20+
* By default, it returns a hardcoded string split on spaces. It's a normal Jest mock, so you can manipulate the result
21+
* further using eg. the `mockReturnValue` function. mock() method is provided as a shorthand.
22+
* @param messages The messages to complete.
23+
* @returns The mocked completion.
24+
*/
25+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
26+
completion = jest.fn(function (this: MockCompletionService, messages: readonly Message[]) {
27+
return ['Example ', 'response ', 'from ', 'the ', 'LLM', '.'];
28+
});
29+
30+
/**
31+
* A shorthand to mock the completion result.
32+
* @param response The response to give to callers. This can be text or a JSON object.
33+
* @returns The mock function that can be used to manipulate the results further.
34+
*/
35+
mock(...response: string[]): typeof this.completion;
36+
mock(response: string): typeof this.completion;
37+
mock(response: unknown): typeof this.completion;
38+
mock(...response: unknown[]): typeof this.completion {
39+
if (response.length > 1 && response.every((x) => typeof x === 'string'))
40+
return this.completion.mockReturnValue(response as string[]);
41+
assert(response.length === 1, 'Only one response is supported');
42+
const cpl = typeof response[0] === 'string' ? response[0] : JSON.stringify(response[0]);
43+
return this.completion.mockReturnValue(cpl.split(/(?= )/));
44+
}
45+
46+
// eslint-disable-next-line @typescript-eslint/require-await
47+
async json<T>(messages: Message[], schema: ZodType<T>): Promise<T | undefined> {
48+
const completion = this.completion(messages).join('');
49+
return schema.parse(JSON.parse(completion));
50+
}
51+
52+
modelName = 'mock-model';
53+
miniModelName = 'mock-mini-model';
54+
temperature = 0.7;
55+
}

packages/navie/test/services/vector-terms-service.spec.ts

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,24 @@
1-
/* eslint-disable @typescript-eslint/no-unsafe-return */
2-
/* eslint-disable @typescript-eslint/no-unsafe-assignment */
3-
4-
import { ChatOpenAI } from '@langchain/openai';
5-
61
import VectorTermsService from '../../src/services/vector-terms-service';
72
import InteractionHistory from '../../src/interaction-history';
8-
import { mockAIResponse } from '../fixture';
9-
import OpenAICompletionService from '../../src/services/openai-completion-service';
10-
import Trajectory from '../../src/lib/trajectory';
11-
import MessageTokenReducerService from '../../src/services/message-token-reducer-service';
12-
13-
jest.mock('@langchain/openai');
14-
const completionWithRetry = jest.mocked(ChatOpenAI.prototype.completionWithRetry);
3+
import MockCompletionService from './mock-completion-service';
154

165
describe('VectorTermsService', () => {
176
let interactionHistory: InteractionHistory;
187
let service: VectorTermsService;
19-
let trajectory: Trajectory;
8+
const completion = new MockCompletionService();
9+
const complete = jest.spyOn(completion, 'complete');
2010

2111
beforeEach(() => {
2212
interactionHistory = new InteractionHistory();
2313
interactionHistory.on('event', (event) => console.log(event.message));
24-
trajectory = new Trajectory();
25-
service = new VectorTermsService(
26-
interactionHistory,
27-
new OpenAICompletionService('gpt-4', 0.5, trajectory, new MessageTokenReducerService())
28-
);
14+
service = new VectorTermsService(interactionHistory, completion);
2915
});
30-
afterEach(() => jest.resetAllMocks());
16+
afterEach(() => jest.restoreAllMocks());
3117

3218
describe('when LLM suggested terms', () => {
3319
describe('is a valid JSON object', () => {
3420
it('is recorded in the interaction history', async () => {
35-
mockAIResponse(completionWithRetry, [`{"terms": ["user", "management"]}`]);
21+
completion.mock(`{"terms": ["user", "management"]}`);
3622
await service.suggestTerms('user management');
3723
expect(interactionHistory.events.map((e) => ({ ...e }))).toEqual([
3824
{
@@ -42,73 +28,73 @@ describe('VectorTermsService', () => {
4228
]);
4329
});
4430
it('should return the terms', async () => {
45-
mockAIResponse(completionWithRetry, [`{"terms": ["user", "management"]}`]);
31+
completion.mock(`{"terms": ["user", "management"]}`);
4632
const terms = await service.suggestTerms('user management');
4733
expect(terms).toEqual(['user', 'management']);
48-
expect(completionWithRetry).toHaveBeenCalledTimes(1);
34+
expect(complete).toHaveBeenCalledTimes(1);
4935
});
5036
it('removes very short terms', async () => {
51-
mockAIResponse(completionWithRetry, [`["user", "management", "a"]`]);
37+
completion.mock(`["user", "management", "a"]`);
5238
const terms = await service.suggestTerms('user management');
5339
expect(terms).toEqual(['user', 'management', 'a']);
54-
expect(completionWithRetry).toHaveBeenCalledTimes(1);
40+
expect(complete).toHaveBeenCalledTimes(1);
5541
});
5642
it('converts underscore_words to distinct words', async () => {
57-
mockAIResponse(completionWithRetry, [`["user_management"]`]);
43+
completion.mock(`["user_management"]`);
5844
const terms = await service.suggestTerms('user management');
5945
expect(terms).toEqual(['user_management']);
60-
expect(completionWithRetry).toHaveBeenCalledTimes(1);
46+
expect(complete).toHaveBeenCalledTimes(1);
6147
});
6248
});
6349

6450
describe('are a valid JSON list', () => {
6551
it('should return the terms', async () => {
66-
mockAIResponse(completionWithRetry, ['["user", "management"]']);
52+
completion.mock('["user", "management"]');
6753
const terms = await service.suggestTerms('user management');
6854
expect(terms).toEqual(['user', 'management']);
6955
});
7056
});
7157

7258
describe('are valid JSON wrapped in fences', () => {
7359
it('should return the terms', async () => {
74-
mockAIResponse(completionWithRetry, ['```json\n', '["user", "management"]\n', '```\n']);
60+
completion.mock('```json\n["user", "management"]\n```\n');
7561
const terms = await service.suggestTerms('user management');
7662
expect(terms).toEqual(['user', 'management']);
7763
});
7864
});
7965

8066
describe('is YAML', () => {
8167
it('parses the terms', async () => {
82-
mockAIResponse(completionWithRetry, ['response_key:\n', ' - user\n', ' - management\n']);
68+
completion.mock('response_key:\n', ' - user\n', ' - management\n');
8369
const terms = await service.suggestTerms('user management');
8470
expect(terms).toEqual(['response_key:', '-', 'user', 'management']);
8571
});
8672
});
8773

8874
describe('is prefixed by "Terms:"', () => {
8975
it('is accepted and processed', async () => {
90-
mockAIResponse(completionWithRetry, ['Terms: ["user", "management"]']);
76+
completion.mock('Terms: ["user", "management"]');
9177
const terms = await service.suggestTerms('user management');
9278
expect(terms).toEqual(['user', 'management']);
93-
expect(completionWithRetry).toHaveBeenCalledTimes(1);
79+
expect(complete).toHaveBeenCalledTimes(1);
9480
});
9581
});
9682

9783
describe('includes terms with "+" prefix', () => {
9884
it('is accepted and processed', async () => {
99-
mockAIResponse(completionWithRetry, ['Terms: +user management']);
85+
completion.mock('Terms: +user management');
10086
const terms = await service.suggestTerms('user management');
10187
expect(terms).toEqual(['+user', 'management']);
102-
expect(completionWithRetry).toHaveBeenCalledTimes(1);
88+
expect(complete).toHaveBeenCalledTimes(1);
10389
});
10490
});
10591

106-
describe('is list-ish ', () => {
92+
describe('is list-ish', () => {
10793
it('is accepted and processed', async () => {
108-
mockAIResponse(completionWithRetry, ['-user -mgmt']);
94+
completion.mock('-user -mgmt');
10995
const terms = await service.suggestTerms('user management');
11096
expect(terms).toEqual(['-user', '-mgmt']);
111-
expect(completionWithRetry).toHaveBeenCalledTimes(1);
97+
expect(complete).toHaveBeenCalledTimes(1);
11298
});
11399
});
114100
});

0 commit comments

Comments
 (0)