Skip to content

Commit d8d4522

Browse files
authored
Merge pull request #113 from cloudera/DSE-47416
DSE-47416 - Add Model Provider UI
2 parents e86f8e8 + 0867ee1 commit d8d4522

File tree

19 files changed

+913
-18
lines changed

19 files changed

+913
-18
lines changed

app/client/src/Container.tsx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ const pages: MenuItem[] = [
9595
<Link to={`${Pages.EXPORTS}`}>{LABELS[Pages.EXPORTS]}</Link>
9696
),
9797
},
98+
{
99+
key: Pages.SETTINGS,
100+
label: (
101+
<Link to={`${Pages.SETTINGS}`}>{LABELS[Pages.SETTINGS]}</Link>
102+
),
103+
},
98104

99105
// {
100106
// key: Pages.TELEMETRY,

app/client/src/components/JobStatus/jobStatusIcon.tsx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ const IconWrapper = styled.div`
2424
}
2525
`
2626

27+
const StyledIconWrapper = styled(IconWrapper)`
28+
svg {
29+
color: #008cff;
30+
}
31+
`;
32+
2733

2834
export default function JobStatusIcon({ status, customTooltipTitles }: JobStatusProps) {
2935
const tooltipTitles = {...defaultTooltipTitles, ...customTooltipTitles};
@@ -44,11 +50,11 @@ export default function JobStatusIcon({ status, customTooltipTitles }: JobStatus
4450
</Tooltip>;
4551
case 'ENGINE_SCHEDULING':
4652
return <Tooltip title={tooltipTitles.ENGINE_SCHEDULING}>
47-
<IconWrapper><LoadingOutlined spin/></IconWrapper>
53+
<StyledIconWrapper><LoadingOutlined spin/></StyledIconWrapper>
4854
</Tooltip>;
4955
case 'ENGINE_RUNNING':
5056
return <Tooltip title={tooltipTitles.ENGINE_RUNNING}>
51-
<IconWrapper><LoadingOutlined spin /></IconWrapper>
57+
<StyledIconWrapper><LoadingOutlined spin /></StyledIconWrapper>
5258
</Tooltip>;
5359
case null:
5460
return <Tooltip title={tooltipTitles.null}>

app/client/src/constants.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export const LABELS = {
1212
[Pages.EXPORTS]: 'Exports',
1313
[Pages.HISTORY]: 'History',
1414
[Pages.FEEDBACK]: 'Feedback',
15+
[Pages.SETTINGS]: 'Settings',
1516
//[Pages.TELEMETRY]: 'Telemetry',
1617
[ModelParameters.TEMPERATURE]: 'Temperature',
1718
[ModelParameters.TOP_K]: 'Top K',

app/client/src/pages/DataGenerator/Configure.tsx

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,23 @@ import endsWith from 'lodash/endsWith';
22
import isEmpty from 'lodash/isEmpty';
33
import isFunction from 'lodash/isFunction';
44
import { FunctionComponent, useEffect, useState } from 'react';
5-
import { Flex, Form, FormInstance, Input, Select, Typography } from 'antd';
5+
import { Flex, Form, Input, Select, Typography } from 'antd';
66
import styled from 'styled-components';
77
import { File, WorkflowType } from './types';
88
import { useFetchModels } from '../../api/api';
99
import { MODEL_PROVIDER_LABELS } from './constants';
1010
import { ModelProviders, ModelProvidersDropdownOpts } from './types';
11-
import { getWizardModel, getWizardModeType, useWizardCtx } from './utils';
11+
import { getWizardModeType, useWizardCtx } from './utils';
1212
import FileSelectorButton from './FileSelectorButton';
1313
import UseCaseSelector from './UseCaseSelector';
1414
import { useLocation, useParams } from 'react-router-dom';
1515
import { WizardModeType } from '../../types';
16+
import get from 'lodash/get';
17+
import forEach from 'lodash/forEach';
18+
import { useModelProviders } from '../Settings/hooks';
19+
import { ModelProviderType } from '../Settings/AddModelProviderButton';
20+
import { CustomModel } from '../Settings/SettingsPage';
21+
import filter from 'lodash/filter';
1622

1723

1824
const StepContainer = styled(Flex)`
@@ -47,14 +53,21 @@ export const WORKFLOW_OPTIONS = [
4753
export const MODEL_TYPE_OPTIONS: ModelProvidersDropdownOpts = [
4854
{ label: MODEL_PROVIDER_LABELS[ModelProviders.BEDROCK], value: ModelProviders.BEDROCK},
4955
{ label: MODEL_PROVIDER_LABELS[ModelProviders.CAII], value: ModelProviders.CAII },
56+
{ label: MODEL_PROVIDER_LABELS[ModelProviders.OPENAI], value: ModelProviders.OPENAI },
57+
{ label: MODEL_PROVIDER_LABELS[ModelProviders.GEMINI], value: ModelProviders.GEMINI },
5058
];
5159

5260
const Configure: FunctionComponent = () => {
5361
const form = Form.useFormInstance();
5462
const formData = Form.useWatch((values) => values, form);
5563
const location = useLocation();
5664
const { template_name, generate_file_name } = useParams();
65+
const [models, setModels] = useState<string[]>([])
5766
const [wizardModeType, setWizardModeType] = useState(getWizardModeType(location));
67+
const { data } = useFetchModels();
68+
const customModelPrividersReq = useModelProviders();
69+
const customModels = get(customModelPrividersReq, 'data.endpoints', []);
70+
console.log('customModels', customModels);
5871

5972
useEffect(() => {
6073
if (wizardModeType === WizardModeType.DATA_AUGMENTATION) {
@@ -77,10 +90,18 @@ const Configure: FunctionComponent = () => {
7790
}
7891
}, [template_name]);
7992

93+
useEffect(() => {
94+
// set model providers
95+
// set model ids
96+
if (formData && (formData?.inference_type === ModelProviderType.OPENAI || formData?.inference_type === ModelProviderType.GEMINI) && isEmpty(generate_file_name)) {
97+
form.setFieldValue('inference_type', ModelProviders.OPENAI);
98+
}
99+
100+
}, [customModels, formData]);
101+
80102

81103
// let formData = Form.useWatch((values) => values, form);
82104
const { setIsStepValid } = useWizardCtx();
83-
const { data } = useFetchModels();
84105
const [selectedFiles, setSelectedFiles] = useState(
85106
!isEmpty(form.getFieldValue('doc_paths')) ? form.getFieldValue('doc_paths') : []);
86107

@@ -104,7 +125,6 @@ const Configure: FunctionComponent = () => {
104125

105126

106127
useEffect(() => {
107-
console.log('useEffect 1');
108128
if (formData && formData?.inference_type === undefined && isEmpty(generate_file_name)) {
109129
form.setFieldValue('inference_type', ModelProviders.CAII);
110130
setTimeout(() => {
@@ -155,6 +175,20 @@ const Configure: FunctionComponent = () => {
155175
}
156176
}
157177

178+
const onModelProviderChange = (value: string) => {
179+
form.setFieldValue('model_id', undefined)
180+
console.log('value', value);
181+
if (ModelProviderType.OPENAI === value) {
182+
const _models = filter(customModels, (model: CustomModel) => model.provider_type === ModelProviderType.OPENAI);
183+
setModels(_models.map((_model: CustomModel) => _model.model_id));
184+
} else if (ModelProviderType.GEMINI === value) {
185+
const _models = filter(customModels, (model: CustomModel) => model.provider_type === ModelProviderType.GEMINI);
186+
setModels(_models.map((_model: CustomModel) => _model.model_id));
187+
}
188+
}
189+
console.log('models', models);
190+
191+
158192
return (
159193
<StepContainer justify='center'>
160194
<FormContainer vertical>
@@ -178,7 +212,7 @@ const Configure: FunctionComponent = () => {
178212
>
179213
<Select
180214

181-
onChange={() => form.setFieldValue('model_id', undefined)}
215+
onChange={(value: string) => onModelProviderChange(value)}
182216
placeholder={'Select a model provider'}
183217
>
184218
{MODEL_TYPE_OPTIONS.map(({ label, value }, i) =>
@@ -200,15 +234,22 @@ const Configure: FunctionComponent = () => {
200234
{formData?.inference_type === ModelProviders.CAII ? (
201235
<Input placeholder={'Enter Cloudera AI Inference Model ID'}/>
202236
) : (
203-
<Select placeholder={'Select a Model'} notFoundContent={'You must select a Model Provider before selecting a Model'}>
204-
{!isEmpty(data?.models) && data?.models[ModelProviders.BEDROCK]?.map((model, i) =>
237+
<Select
238+
placeholder={'Select a Model'}
239+
notFoundContent={'You must select a Model Provider before selecting a Model'}
240+
>
241+
{formData?.inference_type === ModelProviders.BEDROCK && data?.models?.[ModelProviders.BEDROCK]?.map((model, i) => (
205242
<Select.Option key={`${model}-${i}`} value={model}>
206243
{model}
207244
</Select.Option>
208-
)}
245+
))}
246+
{(formData?.inference_type === ModelProviders.OPENAI || formData?.inference_type === ModelProviders.GEMINI) && models?.map((model, i) => (
247+
<Select.Option key={`${model}-${i}`} value={model}>
248+
{model}
249+
</Select.Option>
250+
))}
209251
</Select>
210252
)}
211-
212253
</Form.Item>
213254
{formData?.inference_type === ModelProviders.CAII && (
214255
<>

app/client/src/pages/DataGenerator/Examples.tsx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,13 @@ const Examples: FunctionComponent = () => {
132132
};
133133

134134
const showEmptyState = (workflowType === WorkflowType.FREE_FORM_DATA_GENERATION &&
135-
isEmpty(mutation.data) &&
135+
isEmpty(mutation.data) && Array.isArray(records) &&
136136
records.length === 0) ||
137137
(form.getFieldValue('use_case') === 'custom' &&
138138
isEmpty(form.getFieldValue('examples')));
139139

140140

141+
console.log('records', records);
141142
return (
142143
<Container>
143144
{mutation?.isPending || restore_mutation.isPending && <Loading />}

app/client/src/pages/DataGenerator/Finish.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,14 @@ const Finish = () => {
322322
const rawData = genDatasetResp !== null && hasTopics(genDatasetResp) ?
323323
getRawData(genDatasetResp) : genDatasetResp?.results
324324

325+
console.log('Finish >> ', isDemo);
326+
325327
return (
326328
<div>
327329
<Title level={2}>
328330
<Flex align='center' gap={10}>
329331
<CheckCircleIcon style={{ color: '#178718' }}/>
330-
{'Success'}
332+
{isDemo ? 'Success' : 'Job Successfully Started'}
331333
</Flex>
332334
</Title>
333335
{isDemo ? (

app/client/src/pages/DataGenerator/FreeFormExampleTable.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ interface Props {
4848
const FreeFormExampleTable: FunctionComponent<Props> = ({ data }) => {
4949
const [colDefs, setColDefs] = useState([]);
5050
const [rowData, setRowData] = useState([]);
51+
console.log('FreeFormExampleTable', data);
5152

5253
useEffect(() => {
5354
if (!isEmpty(data)) {

app/client/src/pages/DataGenerator/Success.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ const Success: FC<SuccessProps> = ({ formData, isDemo = true }) => {
122122
<Title level={2}>
123123
<Flex align='center' gap={10}>
124124
<CheckCircleIcon style={{ color: '#178718' }}/>
125-
{'Success'}
125+
{isDemo ? 'Success' : 'Job successfully started.'}
126126
</Flex>
127127
</Title>
128128
{isDemo ? (

app/client/src/pages/DataGenerator/constants.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ export const MODEL_PROVIDER_LABELS = {
55
[ModelProviders.CAII]: 'Cloudera AI Inference Service',
66
[ModelProviders.GOOGLE_GEMINI]: 'Google Gemini',
77
[ModelProviders.AZURE_OPENAI]: 'Azure OpenAI',
8+
[ModelProviders.GEMINI]: 'Gemini',
9+
[ModelProviders.OPENAI]: 'OpenAI'
810
};
911

1012
export const MIN_SEED_INSTRUCTIONS = 1

app/client/src/pages/DataGenerator/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ export enum ModelProviders {
1919
CAII = 'CAII',
2020
AZURE_OPENAI = 'AZURE_OPENAI',
2121
GOOGLE_GEMINI = 'GOOGLE_GEMINI',
22+
OPENAI = 'openai',
23+
GEMINI = 'gemini',
2224
}
2325

2426
export type ModelProvidersDropdownOpts = { label: string, value: ModelProviders }[];

0 commit comments

Comments
 (0)