diff --git a/docs/bundles/ai-bundle.rst b/docs/bundles/ai-bundle.rst index 336d7c2e4..f2739a4e5 100644 --- a/docs/bundles/ai-bundle.rst +++ b/docs/bundles/ai-bundle.rst @@ -136,6 +136,29 @@ Advanced Example with Multiple Agents vectorizer: 'ai.vectorizer.mistral_embeddings' store: 'ai.store.memory.research' +Cached platform +--------------- + +Thanks to Symfony's Cache component, platforms can be decorated and use any cache adapter, +this platform allows reducing network calls / resource consumption: + +.. code-block:: yaml + + # config/packages/ai.yaml + ai: + platform: + openai: + api_key: '%env(OPENAI_API_KEY)%' + cache: + openai: + platform: 'ai.platform.openai' + service: 'cache.app' + + agent: + openai: + platform: 'ai.platform.cache.openai' + model: 'gpt-4o-mini' + Store Dependency Injection -------------------------- diff --git a/docs/components/platform.rst b/docs/components/platform.rst index 6bea67f72..164466771 100644 --- a/docs/components/platform.rst +++ b/docs/components/platform.rst @@ -374,6 +374,29 @@ which can be useful to speed up the processing:: echo $result->asText().PHP_EOL; } +Cached Platform Calls +--------------------- + +Thanks to Symfony's Cache component, platform calls can be cached to reduce network calls and resource consumption:: + + use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory; + use Symfony\AI\Platform\CachedPlatform; + use Symfony\AI\Platform\Message\Message; + use Symfony\AI\Platform\Message\MessageBag; + use Symfony\Component\Cache\Adapter\ArrayAdapter; + use Symfony\Component\Cache\Adapter\TagAwareAdapter; + + $platform = PlatformFactory::create($apiKey, eventDispatcher: $dispatcher); + $cachedPlatform = new CachedPlatform($platform, new TagAwareAdapter(new ArrayAdapter())); + + $firstResult = $cachedPlatform->invoke('gpt-4o-mini', new MessageBag(Message::ofUser('What is the capital of France?'))); + + echo $firstResult->getContent().\PHP_EOL; + + $secondResult = $cachedPlatform->invoke('gpt-4o-mini', new 
MessageBag(Message::ofUser('What is the capital of France?'))); + + echo $secondResult->getContent().\PHP_EOL; + Testing Tools ------------- diff --git a/examples/.env b/examples/.env index 5ec1b247d..987f1b775 100644 --- a/examples/.env +++ b/examples/.env @@ -16,7 +16,7 @@ VOYAGE_API_KEY= REPLICATE_API_KEY= # For using Ollama -OLLAMA_HOST_URL=http://localhost:11434 +OLLAMA_HOST_URL=http://127.0.0.1:11434 OLLAMA_LLM=llama3.2 OLLAMA_EMBEDDINGS=nomic-embed-text diff --git a/examples/misc/agent-with-cache.php b/examples/misc/agent-with-cache.php new file mode 100644 index 000000000..2594940af --- /dev/null +++ b/examples/misc/agent-with-cache.php @@ -0,0 +1,44 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory; +use Symfony\AI\Platform\CachedPlatform; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\Component\Cache\Adapter\ArrayAdapter; +use Symfony\Component\Cache\Adapter\TagAwareAdapter; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client()); +$cachedPlatform = new CachedPlatform($platform, new TagAwareAdapter(new ArrayAdapter())); + +$agent = new Agent($cachedPlatform, 'qwen3:0.6b-q4_K_M'); +$messages = new MessageBag( + Message::forSystem('You are a helpful assistant.'), + Message::ofUser('Tina has one brother and one sister. 
How many sisters do Tina\'s siblings have?'), +); +$result = $agent->call($messages, [ + 'prompt_cache_key' => 'chat', +]); + +assert($result->getMetadata()->has('cached')); + +echo $result->getContent().\PHP_EOL; + +$secondResult = $agent->call($messages, [ + 'prompt_cache_key' => 'chat', +]); + +assert($secondResult->getMetadata()->has('cached')); + +echo $secondResult->getContent().\PHP_EOL; diff --git a/src/ai-bundle/config/options.php b/src/ai-bundle/config/options.php index 635f764e8..088a0bc29 100644 --- a/src/ai-bundle/config/options.php +++ b/src/ai-bundle/config/options.php @@ -54,6 +54,16 @@ ->end() ->end() ->end() + ->arrayNode('cache') + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->stringNode('platform')->isRequired()->end() + ->stringNode('service')->isRequired()->end() + ->stringNode('cache_key')->end() + ->end() + ->end() + ->end() ->arrayNode('eleven_labs') ->children() ->stringNode('host')->end() @@ -130,6 +140,7 @@ ->defaultValue('http_client') ->info('Service ID of the HTTP client to use') ->end() + ->scalarNode('cache')->end() ->end() ->end() ->arrayNode('cerebras') diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index 25d6a3ea1..66a2e4edd 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -55,6 +55,7 @@ use Symfony\AI\Platform\Bridge\Scaleway\PlatformFactory as ScalewayPlatformFactory; use Symfony\AI\Platform\Bridge\VertexAi\PlatformFactory as VertexAiPlatformFactory; use Symfony\AI\Platform\Bridge\Voyage\PlatformFactory as VoyagePlatformFactory; +use Symfony\AI\Platform\CachedPlatform; use Symfony\AI\Platform\Exception\RuntimeException; use Symfony\AI\Platform\Message\Content\File; use Symfony\AI\Platform\ModelClientInterface; @@ -293,6 +294,25 @@ private function processPlatformConfig(string $type, array $platform, ContainerB return; } + if ('cache' === $type) { + foreach ($platform as $name => $config) { + $definition = (new 
Definition(CachedPlatform::class)) + ->setDecoratedService($config['platform']) + ->setArguments([ + new Reference('.inner'), + new Reference($config['service']), + $config['cache_key'], + ]) + ->setLazy(true) + ->addTag('proxy', ['interface' => PlatformInterface::class]) + ->addTag('ai.platform', ['name' => 'cache']); + + $container->setDefinition('ai.platform.cache.'.$name, $definition); + } + + return; + } + if ('eleven_labs' === $type) { $platformId = 'ai.platform.eleven_labs'; $definition = (new Definition(Platform::class)) diff --git a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php index ba1a461c3..3e093f905 100644 --- a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php +++ b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php @@ -2757,6 +2757,41 @@ public function testVectorizerModelBooleanOptionsArePreserved() $this->assertSame('text-embedding-3-small?normalize=false&cache=true&nested%5Bbool%5D=false', $vectorizerDefinition->getArgument(1)); } + public function testCachedPlatformCanBeUsed() + { + $container = $this->buildContainer([ + 'ai' => [ + 'platform' => [ + 'ollama' => [ + 'host_url' => 'http://127.0.0.1:11434', + ], + 'cache' => [ + 'ollama' => [ + 'platform' => 'ai.platform.ollama', + 'service' => 'cache.app', + 'cache_key' => 'ollama', + ], + ], + ], + ], + ]); + + $this->assertTrue($container->hasDefinition('ai.platform.cache.ollama')); + + $definition = $container->getDefinition('ai.platform.cache.ollama'); + $this->assertCount(3, $definition->getArguments()); + + $this->assertInstanceOf(Reference::class, $definition->getArgument(0)); + $platformArgument = $definition->getArgument(0); + $this->assertSame('.inner', (string) $platformArgument); + + $this->assertInstanceOf(Reference::class, $definition->getArgument(1)); + $cacheArgument = $definition->getArgument(1); + $this->assertSame('cache.app', (string) $cacheArgument); + + $this->assertSame('ollama', 
$definition->getArgument(2)); + } + private function buildContainer(array $configuration): ContainerBuilder { $container = new ContainerBuilder(); @@ -2795,6 +2830,13 @@ private function getFullConfig(): array 'api_version' => '2024-02-15-preview', ], ], + 'cache' => [ + 'azure' => [ + 'platform' => 'ai.platform.azure', + 'service' => 'cache.app', + 'cache_key' => 'foo', + ], + ], 'eleven_labs' => [ 'host' => 'https://api.elevenlabs.io/v1', 'api_key' => 'eleven_labs_key_full', diff --git a/src/platform/composer.json b/src/platform/composer.json index d4f13b86e..b01424e79 100644 --- a/src/platform/composer.json +++ b/src/platform/composer.json @@ -65,12 +65,16 @@ "phpstan/phpstan-symfony": "^2.0.6", "phpunit/phpunit": "^11.5", "symfony/ai-agent": "@dev", + "symfony/cache": "^7.3|^8.0", "symfony/console": "^7.3|^8.0", "symfony/dotenv": "^7.3|^8.0", "symfony/finder": "^7.3|^8.0", "symfony/process": "^7.3|^8.0", "symfony/var-dumper": "^7.3|^8.0" }, + "suggest": { + "symfony/cache": "Enable caching for platforms" + }, "autoload": { "psr-4": { "Symfony\\AI\\Platform\\": "src/" diff --git a/src/platform/src/Bridge/Ollama/OllamaResultConverter.php b/src/platform/src/Bridge/Ollama/OllamaResultConverter.php index cf62c61e1..dd76d6d87 100644 --- a/src/platform/src/Bridge/Ollama/OllamaResultConverter.php +++ b/src/platform/src/Bridge/Ollama/OllamaResultConverter.php @@ -43,13 +43,14 @@ public function convert(RawResultInterface $result, array $options = []): Result return \array_key_exists('embeddings', $data) ? 
$this->doConvertEmbeddings($data) - : $this->doConvertCompletion($data); + : $this->doConvertCompletion($data, $options); } /** * @param array $data + * @param array $options */ - public function doConvertCompletion(array $data): ResultInterface + public function doConvertCompletion(array $data, array $options): ResultInterface { if (!isset($data['message'])) { throw new RuntimeException('Response does not contain message.'); @@ -69,7 +70,19 @@ public function doConvertCompletion(array $data): ResultInterface return new ToolCallResult(...$toolCalls); } - return new TextResult($data['message']['content']); + $result = new TextResult($data['message']['content']); + + if (\array_key_exists('prompt_cache_key', $options)) { + $metadata = $result->getMetadata(); + + $metadata->add('cached', true); + $metadata->add('prompt_cache_key', $options['prompt_cache_key']); + $metadata->add('cached_prompt_count', $data['prompt_eval_count']); + $metadata->add('cached_completion_count', $data['eval_count']); + $metadata->add('cached_time', (new \DateTimeImmutable())->getTimestamp()); + } + + return $result; } /** diff --git a/src/platform/src/CachedPlatform.php b/src/platform/src/CachedPlatform.php new file mode 100644 index 000000000..36d2c56be --- /dev/null +++ b/src/platform/src/CachedPlatform.php @@ -0,0 +1,69 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace Symfony\AI\Platform; + +use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface; +use Symfony\AI\Platform\Result\DeferredResult; +use Symfony\Component\Cache\Adapter\TagAwareAdapterInterface; +use Symfony\Contracts\Cache\CacheInterface; +use Symfony\Contracts\Cache\ItemInterface; + +/** + * @author Guillaume Loulier + */ +final class CachedPlatform implements PlatformInterface +{ + public function __construct( + private readonly PlatformInterface $platform, + private readonly (CacheInterface&TagAwareAdapterInterface)|null $cache = null, + private readonly ?string $cacheKey = null, + ) { + } + + public function invoke(string $model, array|string|object $input, array $options = []): DeferredResult + { + $invokeCall = fn (string $model, array|string|object $input, array $options = []): DeferredResult => $this->platform->invoke($model, $input, $options); + + if ($this->cache instanceof CacheInterface && (\array_key_exists('prompt_cache_key', $options) && '' !== $options['prompt_cache_key'])) { + $cacheKey = \sprintf('%s_%s', $this->cacheKey ?? 
$options['prompt_cache_key'], md5($model)); + + unset($options['prompt_cache_key']); + + return $this->cache->get($cacheKey, static function (ItemInterface $item) use ($invokeCall, $model, $input, $options, $cacheKey): DeferredResult { + $item->tag($model); + + $result = $invokeCall($model, $input, $options); + + $result = new DeferredResult( + $result->getResultConverter(), + $result->getRawResult(), + $options, + ); + + $result->getMetadata()->set([ + 'cached' => true, + 'cache_key' => $cacheKey, + 'cached_at' => (new \DateTimeImmutable())->getTimestamp(), + ]); + + return $result; + }); + } + + return $invokeCall($model, $input, $options); + } + + public function getModelCatalog(): ModelCatalogInterface + { + return $this->platform->getModelCatalog(); + } +} diff --git a/src/platform/src/Result/DeferredResult.php b/src/platform/src/Result/DeferredResult.php index 2448446a8..de5bbce7f 100644 --- a/src/platform/src/Result/DeferredResult.php +++ b/src/platform/src/Result/DeferredResult.php @@ -13,6 +13,7 @@ use Symfony\AI\Platform\Exception\ExceptionInterface; use Symfony\AI\Platform\Exception\UnexpectedResultTypeException; +use Symfony\AI\Platform\Metadata\MetadataAwareTrait; use Symfony\AI\Platform\ResultConverterInterface; use Symfony\AI\Platform\Vector\Vector; @@ -21,6 +22,8 @@ */ final class DeferredResult { + use MetadataAwareTrait; + private bool $isConverted = false; private ResultInterface $convertedResult; @@ -50,6 +53,8 @@ public function getResult(): ResultInterface $this->isConverted = true; } + $this->convertedResult->getMetadata()->set($this->getMetadata()->all()); + return $this->convertedResult; } diff --git a/src/platform/tests/CachedPlatformTest.php b/src/platform/tests/CachedPlatformTest.php new file mode 100644 index 000000000..934334404 --- /dev/null +++ b/src/platform/tests/CachedPlatformTest.php @@ -0,0 +1,58 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source 
code. + */ + +namespace Symfony\AI\Platform\Tests; + +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\CachedPlatform; +use Symfony\AI\Platform\PlatformInterface; +use Symfony\AI\Platform\Result\DeferredResult; +use Symfony\AI\Platform\Result\RawHttpResult; +use Symfony\AI\Platform\Result\TextResult; +use Symfony\AI\Platform\ResultConverterInterface; +use Symfony\Component\Cache\Adapter\ArrayAdapter; +use Symfony\Component\Cache\Adapter\TagAwareAdapter; +use Symfony\Contracts\HttpClient\ResponseInterface as SymfonyHttpResponse; + +final class CachedPlatformTest extends TestCase +{ + public function testPlatformCanReturnCachedResultWhenCalledTwice() + { + $httpResponse = $this->createStub(SymfonyHttpResponse::class); + $rawHttpResult = new RawHttpResult($httpResponse); + + $resultConverter = self::createMock(ResultConverterInterface::class); + $resultConverter->expects($this->once()) + ->method('convert') + ->with($rawHttpResult, []) + ->willReturn(new TextResult('test content')); + + $platform = $this->createMock(PlatformInterface::class); + $platform->expects($this->once())->method('invoke')->willReturn(new DeferredResult($resultConverter, $rawHttpResult)); + + $cachedPlatform = new CachedPlatform( + $platform, + new TagAwareAdapter(new ArrayAdapter()), + ); + + $deferredResult = $cachedPlatform->invoke('foo', 'bar', [ + 'prompt_cache_key' => 'symfony', + ]); + + $this->assertSame('test content', $deferredResult->getResult()->getContent()); + + $secondDeferredResult = $cachedPlatform->invoke('foo', 'bar', [ + 'prompt_cache_key' => 'symfony', + ]); + + $this->assertSame('test content', $secondDeferredResult->getResult()->getContent()); + } +}