diff --git a/README.md b/README.md
index 399f3dd8ecd39fa0f2471115cbc0e21758f525e7..5dc73139214a5408106e5bd0f76e55df77afc4a7 100644
--- a/README.md
+++ b/README.md
@@ -10,13 +10,14 @@ abstraction layer.
 ### Features
 
 - **Chat Integration**: Implement ChatInterface (text and images to text).
+- **Embeddings Integration**: Implement EmbeddingsInterface (text to text embeddings).
 - **Configuration**: Provides a configuration form for Gemini authentication.
 - **Plugin Implementation**: Offers a plugin implementation for the Drupal AI
   module.
 
 
 As Google's Gemini is multimodal thing, we need to implement other interfaces too
-(text to speech, embeddings etc.)
+(text to speech, text to image, etc.)
 
 ## Requirements
 
diff --git a/src/Plugin/AiProvider/GeminiProvider.php b/src/Plugin/AiProvider/GeminiProvider.php
index d77c4066b7ed3206159a3811bbedba674c9ccde2..8cb81cd32642286fae24150c52cae2286afb6e4c 100644
--- a/src/Plugin/AiProvider/GeminiProvider.php
+++ b/src/Plugin/AiProvider/GeminiProvider.php
@@ -9,6 +9,9 @@ use Drupal\ai\OperationType\Chat\ChatInput;
 use Drupal\ai\OperationType\Chat\ChatInterface;
 use Drupal\ai\OperationType\Chat\ChatMessage;
 use Drupal\ai\OperationType\Chat\ChatOutput;
+use Drupal\ai\OperationType\Embeddings\EmbeddingsInput;
+use Drupal\ai\OperationType\Embeddings\EmbeddingsInterface;
+use Drupal\ai\OperationType\Embeddings\EmbeddingsOutput;
 use Drupal\Core\Config\ImmutableConfig;
 use Drupal\Core\StringTranslation\TranslatableMarkup;
 use Drupal\gemini_provider\GeminiChatMessageIterator;
@@ -25,7 +28,7 @@ use Symfony\Component\Yaml\Yaml;
   id: 'gemini',
   label: new TranslatableMarkup('Gemini')
 )]
-class GeminiProvider extends AiProviderClientBase implements ChatInterface {
+class GeminiProvider extends AiProviderClientBase implements ChatInterface, EmbeddingsInterface {
 
   /**
    * The Gemini Client.
@@ -67,6 +70,13 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
 
     if (!empty($models['models'])) {
       foreach ($models['models'] as $model) {
+        // Separate models by operation type: embedding models are listed
+        // only for 'embeddings', all other models only for everything else.
+        // @todo We need to add other operation types here later.
+        $is_embedding_model = preg_match('/^models\/.*embedding-/i', trim($model['name'])) === 1;
+        if ($is_embedding_model !== ($operation_type === 'embeddings')) {
+          continue;
+        }
         $supported_models[$model['name']] = $model['displayName'];
       }
     }
@@ -98,7 +108,10 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
    */
   public function getSupportedOperationTypes(): array {
     // @todo We need to add other operation types here later.
-    return ['chat'];
+    return [
+      'chat',
+      'embeddings',
+    ];
   }
 
   /**
@@ -126,6 +139,33 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
     return $generalConfig;
   }
 
+  /**
+   * {@inheritdoc}
+   */
+  public function getSetupData(): array {
+    return [
+      'key_config_name' => 'api_key',
+      'default_models' => [
+        'chat' => 'models/gemini-1.5-pro',
+        'chat_with_image_vision' => 'models/gemini-1.5-pro',
+        'chat_with_complex_json' => 'models/gemini-1.5-pro',
+        'embeddings' => 'models/embedding-001',
+      ],
+    ];
+  }
+
+  /**
+   * {@inheritdoc}
+   */
+  public function embeddingsVectorSize(string $model_id): int {
+    // Both listed embedding models report a vector size of 768; any other
+    // model id yields 0.
+    return match ($model_id) {
+      'models/embedding-001', 'models/text-embedding-004' => 768,
+      default => 0,
+    };
+  }
+
   /**
    * {@inheritdoc}
    */
@@ -155,6 +195,13 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
           $message->setRole('model');
         }
 
+        // NOTE(review): only 'model' and 'user' pass the role check below.
+        // Mapping 'assistant' to 'user' looks inverted, since assistant turns
+        // are normally the model's own replies; confirm this is intended.
+        if ($message->getRole() === 'assistant') {
+          $message->setRole('user');
+        }
+
         if (!in_array($message->getRole(), ['model', 'user'])) {
           $error_message = sprintf('The role %s, is not supported by Gemini Provider.', $message->getRole());
           throw new AiResponseErrorException($error_message);
@@ -290,15 +337,26 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
   /**
    * {@inheritdoc}
    */
-  public function getSetupData(): array {
-    return [
-      'key_config_name' => 'api_key',
-      'default_models' => [
-        'chat' => 'models/gemini-1.5-pro',
-        'chat_with_image_vision' => 'models/gemini-1.5-pro',
-        'chat_with_complex_json' => 'models/gemini-1.5-pro',
-      ],
-    ];
+  public function embeddings(string|EmbeddingsInput $input, string $model_id, array $tags = []): EmbeddingsOutput {
+    $this->loadClient();
+    // Normalize the input: unwrap the prompt from an EmbeddingsInput object.
+    if ($input instanceof EmbeddingsInput) {
+      $input = $input->getPrompt();
+    }
+
+    // @todo Handle client exceptions properly; for now they propagate to the
+    // caller unchanged.
+    $response = $this->client->embeddingModel($model_id)->embedContent($input);
+
+    return new EmbeddingsOutput($response->embedding->values, $response->toArray(), []);
+  }
+
+  /**
+   * {@inheritdoc}
+   */
+  public function maxEmbeddingsInput($model_id = ''): int {
+    // The same limit is returned for every model id.
+    return 2048;
   }
 
 }