@@ -2,17 +2,23 @@ import endsWith from 'lodash/endsWith';
22import isEmpty from 'lodash/isEmpty' ;
33import isFunction from 'lodash/isFunction' ;
44import { FunctionComponent , useEffect , useState } from 'react' ;
5- import { Flex , Form , FormInstance , Input , Select , Typography } from 'antd' ;
5+ import { Flex , Form , Input , Select , Typography } from 'antd' ;
66import styled from 'styled-components' ;
77import { File , WorkflowType } from './types' ;
88import { useFetchModels } from '../../api/api' ;
99import { MODEL_PROVIDER_LABELS } from './constants' ;
1010import { ModelProviders , ModelProvidersDropdownOpts } from './types' ;
11- import { getWizardModel , getWizardModeType , useWizardCtx } from './utils' ;
11+ import { getWizardModeType , useWizardCtx } from './utils' ;
1212import FileSelectorButton from './FileSelectorButton' ;
1313import UseCaseSelector from './UseCaseSelector' ;
1414import { useLocation , useParams } from 'react-router-dom' ;
1515import { WizardModeType } from '../../types' ;
16+ import get from 'lodash/get' ;
17+ import forEach from 'lodash/forEach' ;
18+ import { useModelProviders } from '../Settings/hooks' ;
19+ import { ModelProviderType } from '../Settings/AddModelProviderButton' ;
20+ import { CustomModel } from '../Settings/SettingsPage' ;
21+ import filter from 'lodash/filter' ;
1622
1723
1824const StepContainer = styled ( Flex ) `
@@ -47,14 +53,21 @@ export const WORKFLOW_OPTIONS = [
4753export const MODEL_TYPE_OPTIONS : ModelProvidersDropdownOpts = [
4854 { label : MODEL_PROVIDER_LABELS [ ModelProviders . BEDROCK ] , value : ModelProviders . BEDROCK } ,
4955 { label : MODEL_PROVIDER_LABELS [ ModelProviders . CAII ] , value : ModelProviders . CAII } ,
56+ { label : MODEL_PROVIDER_LABELS [ ModelProviders . OPENAI ] , value : ModelProviders . OPENAI } ,
57+ { label : MODEL_PROVIDER_LABELS [ ModelProviders . GEMINI ] , value : ModelProviders . GEMINI } ,
5058] ;
5159
5260const Configure : FunctionComponent = ( ) => {
5361 const form = Form . useFormInstance ( ) ;
5462 const formData = Form . useWatch ( ( values ) => values , form ) ;
5563 const location = useLocation ( ) ;
5664 const { template_name, generate_file_name } = useParams ( ) ;
65+ const [ models , setModels ] = useState < string [ ] > ( [ ] )
5766 const [ wizardModeType , setWizardModeType ] = useState ( getWizardModeType ( location ) ) ;
67+ const { data } = useFetchModels ( ) ;
68+ const customModelPrividersReq = useModelProviders ( ) ;
69+ const customModels = get ( customModelPrividersReq , 'data.endpoints' , [ ] ) ;
70+ console . log ( 'customModels' , customModels ) ;
5871
5972 useEffect ( ( ) => {
6073 if ( wizardModeType === WizardModeType . DATA_AUGMENTATION ) {
@@ -77,10 +90,18 @@ const Configure: FunctionComponent = () => {
7790 }
7891 } , [ template_name ] ) ;
7992
93+ useEffect ( ( ) => {
94+ // set model providers
95+ // set model ids
96+ if ( formData && ( formData ?. inference_type === ModelProviderType . OPENAI || formData ?. inference_type === ModelProviderType . GEMINI ) && isEmpty ( generate_file_name ) ) {
97+ form . setFieldValue ( 'inference_type' , ModelProviders . OPENAI ) ;
98+ }
99+
100+ } , [ customModels , formData ] ) ;
101+
80102
81103 // let formData = Form.useWatch((values) => values, form);
82104 const { setIsStepValid } = useWizardCtx ( ) ;
83- const { data } = useFetchModels ( ) ;
84105 const [ selectedFiles , setSelectedFiles ] = useState (
85106 ! isEmpty ( form . getFieldValue ( 'doc_paths' ) ) ? form . getFieldValue ( 'doc_paths' ) : [ ] ) ;
86107
@@ -104,7 +125,6 @@ const Configure: FunctionComponent = () => {
104125
105126
106127 useEffect ( ( ) => {
107- console . log ( 'useEffect 1' ) ;
108128 if ( formData && formData ?. inference_type === undefined && isEmpty ( generate_file_name ) ) {
109129 form . setFieldValue ( 'inference_type' , ModelProviders . CAII ) ;
110130 setTimeout ( ( ) => {
@@ -155,6 +175,20 @@ const Configure: FunctionComponent = () => {
155175 }
156176 }
157177
178+ const onModelProviderChange = ( value : string ) => {
179+ form . setFieldValue ( 'model_id' , undefined )
180+ console . log ( 'value' , value ) ;
181+ if ( ModelProviderType . OPENAI === value ) {
182+ const _models = filter ( customModels , ( model : CustomModel ) => model . provider_type === ModelProviderType . OPENAI ) ;
183+ setModels ( _models . map ( ( _model : CustomModel ) => _model . model_id ) ) ;
184+ } else if ( ModelProviderType . GEMINI === value ) {
185+ const _models = filter ( customModels , ( model : CustomModel ) => model . provider_type === ModelProviderType . GEMINI ) ;
186+ setModels ( _models . map ( ( _model : CustomModel ) => _model . model_id ) ) ;
187+ }
188+ }
189+ console . log ( 'models' , models ) ;
190+
191+
158192 return (
159193 < StepContainer justify = 'center' >
160194 < FormContainer vertical >
@@ -178,7 +212,7 @@ const Configure: FunctionComponent = () => {
178212 >
179213 < Select
180214
181- onChange = { ( ) => form . setFieldValue ( 'model_id' , undefined ) }
215+ onChange = { ( value : string ) => onModelProviderChange ( value ) }
182216 placeholder = { 'Select a model provider' }
183217 >
184218 { MODEL_TYPE_OPTIONS . map ( ( { label, value } , i ) =>
@@ -200,15 +234,22 @@ const Configure: FunctionComponent = () => {
200234 { formData ?. inference_type === ModelProviders . CAII ? (
201235 < Input placeholder = { 'Enter Cloudera AI Inference Model ID' } />
202236 ) : (
203- < Select placeholder = { 'Select a Model' } notFoundContent = { 'You must select a Model Provider before selecting a Model' } >
204- { ! isEmpty ( data ?. models ) && data ?. models [ ModelProviders . BEDROCK ] ?. map ( ( model , i ) =>
237+ < Select
238+ placeholder = { 'Select a Model' }
239+ notFoundContent = { 'You must select a Model Provider before selecting a Model' }
240+ >
241+ { formData ?. inference_type === ModelProviders . BEDROCK && data ?. models ?. [ ModelProviders . BEDROCK ] ?. map ( ( model , i ) => (
205242 < Select . Option key = { `${ model } -${ i } ` } value = { model } >
206243 { model }
207244 </ Select . Option >
208- ) }
245+ ) ) }
246+ { ( formData ?. inference_type === ModelProviders . OPENAI || formData ?. inference_type === ModelProviders . GEMINI ) && models ?. map ( ( model , i ) => (
247+ < Select . Option key = { `${ model } -${ i } ` } value = { model } >
248+ { model }
249+ </ Select . Option >
250+ ) ) }
209251 </ Select >
210252 ) }
211-
212253 </ Form . Item >
213254 { formData ?. inference_type === ModelProviders . CAII && (
214255 < >
0 commit comments