Skip to content
Snippets Groups Projects

Issue #3504309 by vasike: Add Embedding capabilities - extend GeminiProvider implementation.

Merged Issue #3504309 by vasike: Add Embedding capabilities - extend GeminiProvider implementation.
All threads resolved!
All threads resolved!
Files
2
@@ -9,6 +9,9 @@ use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatInterface;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\OperationType\Chat\ChatOutput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInterface;
use Drupal\ai\OperationType\Embeddings\EmbeddingsOutput;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\gemini_provider\GeminiChatMessageIterator;
@@ -25,7 +28,7 @@ use Symfony\Component\Yaml\Yaml;
id: 'gemini',
label: new TranslatableMarkup('Gemini')
)]
class GeminiProvider extends AiProviderClientBase implements ChatInterface {
class GeminiProvider extends AiProviderClientBase implements ChatInterface, EmbeddingsInterface {
/**
* The Gemini Client.
@@ -67,6 +70,21 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
if (!empty($models['models'])) {
foreach ($models['models'] as $model) {
// Separate models by operation type.
switch ($operation_type) {
case 'embeddings':
if (!preg_match('/^(models\/)(.)*(embedding-)/i', trim($model['name']))) {
continue 2;
}
break;
// @todo We need to add other operation types here later.
default:
if (preg_match('/^(models\/)(.)*(embedding-)/i', trim($model['name']))) {
continue 2;
}
break;
}
$supported_models[$model['name']] = $model['displayName'];
}
}
@@ -98,7 +116,10 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
*/
public function getSupportedOperationTypes(): array {
// @todo We need to add other operation types here later.
return ['chat'];
return [
'chat',
'embeddings',
];
}
/**
@@ -126,6 +147,31 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
return $generalConfig;
}
/**
* {@inheritdoc}
*/
public function getSetupData(): array {
return [
'key_config_name' => 'api_key',
'default_models' => [
'chat' => 'models/gemini-1.5-pro',
'chat_with_image_vision' => 'models/gemini-1.5-pro',
'chat_with_complex_json' => 'models/gemini-1.5-pro',
'embeddings' => 'models/embedding-001',
],
];
}
/**
* {@inheritdoc}
*/
public function embeddingsVectorSize(string $model_id): int {
return match ($model_id) {
'models/embedding-001', 'models/text-embedding-004' => 768,
default => 0,
};
}
/**
* {@inheritdoc}
*/
@@ -155,6 +201,10 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
$message->setRole('model');
}
if ($message->getRole() == 'assistant') {
$message->setRole('user');
}
if (!in_array($message->getRole(), ['model', 'user'])) {
$error_message = sprintf('The role %s, is not supported by Gemini Provider.', $message->getRole());
throw new AiResponseErrorException($error_message);
@@ -290,15 +340,28 @@ class GeminiProvider extends AiProviderClientBase implements ChatInterface {
/**
* {@inheritdoc}
*/
public function getSetupData(): array {
return [
'key_config_name' => 'api_key',
'default_models' => [
'chat' => 'models/gemini-1.5-pro',
'chat_with_image_vision' => 'models/gemini-1.5-pro',
'chat_with_complex_json' => 'models/gemini-1.5-pro',
],
];
public function embeddings(string|EmbeddingsInput $input, string $model_id, array $tags = []): EmbeddingsOutput {
$this->loadClient();
// Normalize the input if needed.
if ($input instanceof EmbeddingsInput) {
$input = $input->getPrompt();
}
try {
$response = $this->client->embeddingModel($model_id)->embedContent($input);
}
catch (\Exception $e) {
// @todo Handle the exception properly.
throw $e;
}
return new EmbeddingsOutput($response->embedding->values, $response->toArray(), []);
}
/**
* {@inheritdoc}
*/
public function maxEmbeddingsInput($model_id = ''): int {
return 2048;
}
}
Loading