Skip to content
Snippets Groups Projects
Commit 7368ebd4 authored by Marcus Johansson's avatar Marcus Johansson
Browse files

Milvus fixes to have first version available

parent bf3f1254
No related branches found
No related tags found
No related merge requests found
<?php
namespace Drupal\search_api_ai_milvus\Exception;
/**
 * Signals that a requested operation cannot be carried out by Milvus.
 *
 * Carries no state of its own; everything is inherited from \Exception, so the
 * class only serves as a distinct type to throw and catch.
 */
class MilvusNotPossibleException extends \Exception {}
......@@ -6,11 +6,13 @@ use Drupal\Core\Form\FormStateInterface;
use Drupal\Core\Form\SubformStateInterface;
use Drupal\Core\Plugin\PluginFormInterface;
use Drupal\search_api\IndexInterface;
use Drupal\search_api\Item\FieldInterface;
use Drupal\search_api\Item\ItemInterface;
use Drupal\search_api\Plugin\PluginFormTrait;
use Drupal\search_api\Query\QueryInterface;
use Drupal\search_api_ai\Backend\SearchApiAiBackendPluginBase;
use Drupal\search_api_ai\SearchApiAiBackendInterface;
use Drupal\search_api_ai_milvus\Exception\MilvusNotPossibleException;
use HelgeSverre\Milvus\Milvus;
/**
......@@ -224,12 +226,18 @@ class SearchApiMilvusBackend extends SearchApiAiBackendPluginBase implements Plu
}
}
else {
$itemBase['metadata'][$field->getFieldIdentifier()] = is_array($field->getValues()) ? implode(',', $field->getValues()) : $field->getValues();
// Check the field's cardinality to decide whether to store it as an array.
if ($this->isMultiple($field)) {
$itemBase['metadata'][$field->getFieldIdentifier()] = is_array($field->getValues()) ? $field->getValues() : [$field->getValues()];
}
else {
$itemBase['metadata'][$field->getFieldIdentifier()] = $field->getValues();
}
}
}
foreach ($chunkedItems as $chunkedItem) {
$chunkedItem += $itemBase;
$chunkedItem = array_merge_recursive($chunkedItem, $itemBase);
$data['drupal_long_id'] = $chunkedItem['id'];
$data['drupal_entity_id'] = $item->getId();
$data['vector'] = $chunkedItem['values'];
......@@ -283,13 +291,53 @@ class SearchApiMilvusBackend extends SearchApiAiBackendPluginBase implements Plu
if ($query->hasTag('server_index_status')) {
return NULL;
}
$condition_group = $query->getConditionGroup();
$filters = [];
foreach ($condition_group->getConditions() as $condition) {
$fieldData = $index->getField($condition->getField());
$fieldType = $fieldData->getType();
$isMultiple = $this->isMultiple($fieldData);
// Normalize the value.
$values = is_array($condition->getValue()) ? $condition->getValue() : [$condition->getValue()];
$normalizedValues = "";
if (in_array($fieldType, ['string', 'full_text'])) {
$normalizedValues = '"' . implode('","', $values) . '"';
}
else {
$normalizedValues = implode(',', $values);
}
if ($isMultiple) {
if (in_array($condition->getOperator(), [
'=',
'IN',
])) {
$filters[] = '(ARRAY_CONTAINS(' . $condition->getField() . ', ' . $normalizedValues . '))';
}
else {
throw new MilvusNotPossibleException('Milvus does not support negative operator on multiple fields.');
}
}
else {
$filters[] = '(' . $condition->getField() . $condition->getOperator() . $normalizedValues . ')';
}
}
// Warn if they try sorting in Milvus.
if (!empty($query->getSorts())) {
\Drupal::messenger()->addWarning('Milvus does not support sorting.');
}
// If no vector is set, we return a normal query.
if (empty($query->getOption('query_embedding'))) {
// If filters is empty, we need to set it to something.
if (empty($filters)) {
$filters[] = 'id not in [0]';
}
$response = $this->getClient()->vector()->query(
collectionName: $this->configuration['collection'],
limit: $query->getOption('limit', 10),
offset: $query->getOption('offset', 0),
dbName: $this->configuration['database'],
filter: 'id not in [0]',
filter: implode(' AND ', $filters),
outputFields: ['id', 'drupal_entity_id', 'drupal_long_id', 'content'],
);
}
......@@ -297,7 +345,9 @@ class SearchApiMilvusBackend extends SearchApiAiBackendPluginBase implements Plu
$response = $this->getClient()->vector()->search(
vector: $query->getOption('query_embedding') ?? [],
limit: $query->getOption('limit', 10),
offset: $query->getOption('offset', 0),
collectionName: $this->configuration['collection'],
filter: implode(' AND ', $filters),
dbName: $this->configuration['database'],
outputFields: ['id', 'drupal_entity_id', 'drupal_long_id', 'content'],
);
......@@ -435,4 +485,28 @@ class SearchApiMilvusBackend extends SearchApiAiBackendPluginBase implements Plu
$this->apiKey = $apiKey;
}
/**
 * Determines whether an indexed field can hold more than one value.
 *
 * The first segment of the field's property path is used as the entity field
 * name and the second segment of its datasource ID as the entity type. The
 * matching field storage definition then supplies the cardinality.
 *
 * @param \Drupal\search_api\Item\FieldInterface $field
 *   The Search API field to inspect.
 *
 * @return bool
 *   FALSE only when a matching storage definition is found and its
 *   cardinality is exactly 1. TRUE in every other case, including when the
 *   field name or entity type cannot be resolved.
 */
public function isMultiple(FieldInterface $field): bool {
  // Cast before exploding so a missing property path or datasource ID yields
  // an empty segment instead of a deprecation notice or undefined offset.
  $fieldName = explode(':', (string) $field->getPropertyPath())[0];
  $datasourceParts = explode(':', (string) $field->getDatasourceId());
  $entityTypeId = $datasourceParts[1] ?? '';

  // Without both parts there is nothing to look up; use the same fallback as
  // the "no matching definition" case below.
  if ($fieldName === '' || $entityTypeId === '') {
    return TRUE;
  }

  $storageDefinitions = \Drupal::service('entity_field.manager')->getFieldStorageDefinitions($entityTypeId);
  // A separate loop variable keeps the $field parameter intact.
  foreach ($storageDefinitions as $storageDefinition) {
    if ($storageDefinition->getName() === $fieldName) {
      return $storageDefinition->getCardinality() !== 1;
    }
  }

  // No storage definition matched: treat the field as multi-valued.
  return TRUE;
}
}
......@@ -4,5 +4,5 @@ description: 'Provides AI integrations using Search API as a backend.'
package: Search
core_version_requirement: ^9 || ^10
dependencies:
- openai:openai
- ai:ai
- search_api:search_api
......@@ -2,6 +2,6 @@ search_api_ai.test:
path: '/api-ai-test'
defaults:
_title: 'Test'
_form: '\Drupal\search_api_ai\Form\ChatForm'
_controller: '\Drupal\search_api_ai\Controller\QueryTest::runQuery'
requirements:
_access: 'TRUE'
......@@ -2,11 +2,10 @@ services:
plugin.manager.embedding_engine:
class: Drupal\search_api_ai\EmbeddingEnginePluginManager
parent: default_plugin_manager
search_api_ai.embedding_engine_static:
class: Drupal\search_api_ai\EmbeddingEngineStatic
arguments: ['@plugin.manager.embedding_engine']
search_api_ai.server_engine_static:
class: Drupal\search_api_ai\ServerConfigurationStatic
search_api_ai.event_listener:
class: Drupal\search_api_ai\EventListener\SearchApiEventListener
arguments: ['@search_api_ai.embedding_engine_static']
arguments: ['@search_api_ai.server_engine_static']
tags:
- { name: event_subscriber }
<?php
namespace Drupal\search_api_ai\Controller;
use Drupal\Core\Controller\ControllerBase;
use Drupal\search_api\Query\QueryInterface;
use Symfony\Component\HttpFoundation\Request;
/**
 * Debug controller that runs a sample vector search against an index.
 *
 * Reads the `index` and `query` query-string parameters, embeds the query text
 * with the provider/model configured on the index's server backend, executes
 * a Search API query with that embedding, and renders the matches as markup.
 */
class QueryTest extends ControllerBase {

  /**
   * AI provider plugin manager.
   *
   * @var \Drupal\ai\AiProviderPluginManager
   */
  protected $aiProvider;

  /**
   * Entity type manager.
   *
   * @var \Drupal\Core\Entity\EntityTypeManagerInterface
   */
  protected $entityTypeManager;

  /**
   * Constructor.
   *
   * Both services are resolved through the global container because this
   * class defines no create() factory.
   */
  public function __construct() {
    $this->aiProvider = \Drupal::service('ai.provider');
    $this->entityTypeManager = \Drupal::entityTypeManager();
  }

  /**
   * Test the queries.
   *
   * @param \Symfony\Component\HttpFoundation\Request $request
   *   The current request; `index` and `query` are read from its query string.
   *
   * @return array
   *   A render array holding either the formatted results or a message
   *   explaining why the search could not be run.
   */
  public function runQuery(Request $request): array {
    $index_id = $request->query->get('index');
    $search_query = $request->query->get('query');
    if (!$index_id || !$search_query || !is_string($index_id) || !is_string($search_query)) {
      return [
        '#type' => 'markup',
        '#markup' => $this->t('Please provide index and query parameters.'),
      ];
    }

    /** @var \Drupal\search_api\Entity\Index|null $index */
    $index = $this->entityTypeManager->getStorage('search_api_index')->load($index_id);
    // The index ID comes straight from the URL, so it may not exist.
    if (!$index) {
      return [
        '#type' => 'markup',
        '#markup' => $this->t('The index @index could not be loaded.', ['@index' => $index_id]),
      ];
    }

    $server = $index->getServerInstance();
    if (!$server) {
      return [
        '#type' => 'markup',
        '#markup' => $this->t('The index @index has no server.', ['@index' => $index_id]),
      ];
    }

    /** @var \Drupal\search_api_ai\Backend\SearchApiAiBackendPluginBase $backend */
    $backend = $server->getBackend();
    // The engine is stored as "<provider>__<model>"; refuse anything else
    // instead of destructuring a malformed value.
    $engine = $backend->getConfiguration()['embeddings_engine'] ?? '';
    $parts = explode('__', (string) $engine);
    if (count($parts) !== 2) {
      return [
        '#type' => 'markup',
        '#markup' => $this->t('Invalid embeddings engine configuration.'),
      ];
    }
    [$provider_id, $model_id] = $parts;
    $vectors = $this->aiProvider->createInstance($provider_id)->embeddings($search_query, $model_id)->getNormalized();

    // Vector Search.
    $query = $index->query();
    // The actual vector search.
    $query->setOption('query_embedding', $vectors);
    // Normal search parameters.
    $query->addCondition('name', 'Navy', '=');
    // Limit the results.
    $query->range(0, 10);
    // This will create a warning, since only vector sorting is allowed.
    $query->sort('name');
    $results = $query->execute();

    $escape_flags = ENT_QUOTES | ENT_SUBSTITUTE;
    $output = "Results: <br>";
    foreach ($results as $match) {
      // Indexed values are arbitrary text; escape them so they are shown
      // literally rather than interpreted as markup.
      $entity_id = htmlspecialchars((string) $match->getExtraData('drupal_entity_id'), $escape_flags, 'UTF-8');
      $content = htmlspecialchars(print_r($match->getExtraData('content'), TRUE), $escape_flags, 'UTF-8');
      $output .= '<h3>' . $entity_id . '</h3>';
      $output .= '<b>Score:</b> ' . $match->getScore() . '<br>';
      $output .= '<pre>' . $content . '</pre>';
      $output .= '<hr>';
    }

    return [
      '#type' => 'markup',
      '#markup' => $output,
    ];
  }

}
......@@ -4,8 +4,8 @@ namespace Drupal\search_api_ai\EventListener;
use Drupal\search_api\Event\IndexingItemsEvent;
use Drupal\search_api\Event\SearchApiEvents;
use Drupal\search_api_ai\EmbeddingEngineStatic;
use Drupal\search_api_ai\SearchApiAiBackendInterface;
use Drupal\search_api_ai\ServerConfigurationStatic;
use Symfony\Component\EventDispatcher\EventSubscriberInterface;
/**
......@@ -16,20 +16,20 @@ use Symfony\Component\EventDispatcher\EventSubscriberInterface;
class SearchApiEventListener implements EventSubscriberInterface {
/**
* The embeddings engine.
* The server configuration.
*
* @var \Drupal\search_api_ai\EmbeddingEngineStatic
* @var \Drupal\search_api_ai\ServerConfigurationStatic
*/
protected $embedEngine;
protected $serverConfig;
/**
* Constructs a new class instance.
*
* @param \Drupal\search_api_ai\EmbeddingEngineStatic $embed_engine_static
* The embeddings engine.
* @param \Drupal\search_api_ai\ServerConfigurationStatic $embed_engine_static
* The server configuration.
*/
public function __construct(EmbeddingEngineStatic $embed_engine_static) {
$this->embedEngine = $embed_engine_static;
public function __construct(ServerConfigurationStatic $server_config_static) {
$this->serverConfig = $server_config_static;
}
/**
......@@ -54,14 +54,7 @@ class SearchApiEventListener implements EventSubscriberInterface {
if (!($server instanceof SearchApiAiBackendInterface)) {
return;
}
/** @var \Drupal\search_api_ai\Backend\SearchApiAiBackendPluginBase $server */
$server->setEngineConfiguration($server->getConfiguration());
$embeddings_engine = $server->loadEmbeddingsEngine();
if ($embeddings_engine) {
// Because we have no context during indexing, we need to store the engine
// in a static class, so that the Embeddings data type can react to it.
$this->embedEngine->setEmbeddingEngine($embeddings_engine);
}
$config = $server->getConfiguration();
$this->serverConfig->setConfiguration($config);
}
}
......@@ -2,9 +2,11 @@
namespace Drupal\search_api_ai\Plugin\search_api\data_type;
use Drupal\ai\AiProviderPluginManager;
use Drupal\openai\Utility\StringHelper;
use Drupal\search_api\DataType\DataTypePluginBase;
use Drupal\search_api_ai\EmbeddingEngineStatic;
use Drupal\search_api_ai\ServerConfigurationStatic;
use Drupal\search_api_ai\TextChunker;
use Symfony\Component\DependencyInjection\ContainerInterface;
......@@ -20,26 +22,33 @@ use Symfony\Component\DependencyInjection\ContainerInterface;
class Embeddings extends DataTypePluginBase {
/**
* Embeddings engine.
* AI Plugin manager.
*
* @var \Drupal\search_api_ai\EmbeddingEngineInterface
* The embeddings engine.
* @var \Drupal\ai\AiProviderPluginManager
*/
protected $embeddingsEngine;
protected $aiPluginManager;
/**
* The configurations manager.
*
* @var array
*/
protected $configuration = [];
/**
* Constructor.
*/
public function __construct(array $configuration, $plugin_id, $plugin_definition, EmbeddingEngineStatic $embeddingEngineStatic) {
public function __construct(array $configuration, $plugin_id, $plugin_definition, AiProviderPluginManager $ai_plugin_manager) {
parent::__construct($configuration, $plugin_id, $plugin_definition);
// Check if the embeddings engine is in the configuration.
if (isset($configuration['embeddings_engine'])) {
$this->embeddingsEngine = $configuration['embeddings_engine'];
if (!empty($configuration)) {
$this->configuration = $configuration;
}
else {
// Otherwise load it from static.
$this->embeddingsEngine = $embeddingEngineStatic->getEmbeddingEngine();
$this->configuration = ServerConfigurationStatic::getConfiguration();
}
$this->aiPluginManager = $ai_plugin_manager;
}
/**
......@@ -50,7 +59,7 @@ class Embeddings extends DataTypePluginBase {
$configuration,
$plugin_id,
$plugin_definition,
$container->get('search_api_ai.embedding_engine_static')
$container->get('ai.provider')
);
}
......@@ -58,7 +67,10 @@ class Embeddings extends DataTypePluginBase {
* {@inheritdoc}
*/
public function getValue($value) {
$chunkMaxSize = $this->embeddingsEngine->getDimension();
if (empty($this->configuration['embeddings_engine']) || empty($this->configuration['embeddings_engine_configuration']['dimension'])) {
throw new \Exception('Embeddings engine not configured.');
}
$chunkMaxSize = $this->configuration['embeddings_engine_configuration']['dimension'];
$chunkMinOverlap = 64;
$chunks = TextChunker::chunkText($value, $chunkMaxSize, $chunkMinOverlap);
......@@ -70,13 +82,15 @@ class Embeddings extends DataTypePluginBase {
continue;
}
$text = StringHelper::prepareText($chunk, [], $chunkMaxSize);
$vectors = $this->embeddingsEngine->generateEmbeddings($text);
$parts = explode('__', $this->configuration['embeddings_engine']);
if (count($parts) !== 2) {
throw new \Exception('Invalid embeddings engine configuration.');
}
$vectors = $this->aiPluginManager->createInstance($parts[0])->embeddings($chunk, $parts[1])->getNormalized();
if (is_array($vectors)) {
$items[$delta] = [
'content' => $text,
'content' => $chunk,
'vectors' => $vectors,
];
}
......
......@@ -6,36 +6,36 @@ namespace Drupal\search_api_ai;
* Static context for the embedding engines.
*
* Since the data type plugins have no context of where they are being used,
* we need to use a static class to temporarily store the embed engine during
* we need to use a static class to temporarily store the server config during
* indexing.
*/
class EmbeddingEngineStatic {
class ServerConfigurationStatic {
/**
* The embedding engine object.
* The server configuration.
*
* @var mixed
* @var array
*/
protected $embeddingEngine;
protected static $serverConfiguration = [];
/**
* Sets the embeddinge engine object.
* Sets the server configuration.
*
* @param mixed $object
* The embedding engine object.
* @param mixed $server_config
* The server configuration.
*/
public function setEmbeddingEngine($object) {
$this->embeddingEngine = $object;
public static function setConfiguration($server_config) {
self::$serverConfiguration = $server_config;
}
/**
* Gets the embedding engine object.
* Gets the server configuration.
*
* @return mixed
* The embedding engine object.
* The server configuration.
*/
public function getEmbeddingEngine() {
return $this->embeddingEngine;
public static function getConfiguration() {
return self::$serverConfiguration;
}
}
......@@ -96,17 +96,18 @@ trait SearchApiAiBackendTrait {
'#type' => 'number',
'#title' => $this->t('Dimensions'),
'#description' => $this->t('The number of dimensions for the embeddings.'),
'#default_value' => $this->traitConfiguration['embeddings_engine_configuration']['dimension'],
'#default_value' => $this->traitConfiguration['embeddings_engine_configuration']['dimension'] ?? '',
'#required' => TRUE,
'#disabled' => TRUE,
];
// If the embeddings engine is set, add the configuration form.
if (!empty($this->traitConfiguration['embeddings_engine']) || $form_state->get('embeddings_engine')) {
$plugin_manager = \Drupal::service('plugin.manager.embedding_engine');
$rule = $plugin_manager->createInstance($this->traitConfiguration['embeddings_engine'] ?? $form_state->get('embeddings_engine'));
foreach ($rule->buildEmbeddingConfigurationForm() as $key => $value) {
$form['embeddings_engine_configuration'][$key] = $value;
$plugin_manager = \Drupal::service('ai.provider');
$parts = explode('__', $this->traitConfiguration['embeddings_engine'] ?? $form_state->get('embeddings_engine'));
$rule = $plugin_manager->createInstance($parts[0])->getAvailableConfiguration('embeddings', $parts[1]);
foreach ($rule as $key => $value) {
$form['embeddings_engine_configuration'][$key]['#default_value'] = $value['default'];
}
}
......@@ -116,12 +117,12 @@ trait SearchApiAiBackendTrait {
/**
* Load the embeddings engine with a configuration.
*
* @return \Drupal\search_api_ai\EmbeddingEngineInterface
* The embeddings engine.
* @return \Drupal\ai\AiProviderInterface
*/
public function loadEmbeddingsEngine() {
$plugin_manager = \Drupal::service('plugin.manager.embedding_engine');
return $plugin_manager->createInstance($this->traitConfiguration['embeddings_engine'], $this->traitConfiguration['embeddings_engine_configuration']);
$plugin_manager = \Drupal::service('ai.provider');
$parts = explode('__', $this->traitConfiguration['embeddings_engine']);
return $plugin_manager->createInstance($parts[0]);
}
/**
......@@ -142,11 +143,11 @@ trait SearchApiAiBackendTrait {
*/
public function getEmbeddingEnginesOptions(): array {
$options = [];
$plugin_manager = \Drupal::service('plugin.manager.embedding_engine');
foreach ($plugin_manager->getDefinitions() as $id => $definition) {
$rule = $plugin_manager->createInstance($id);
if ($rule->isAvailable()) {
$options[$id] = $definition['label'];
$plugin_manager = \Drupal::service('ai.provider');
foreach ($plugin_manager->getProvidersForOperationType('embeddings') as $id => $definition) {
$provider = $plugin_manager->createInstance($id);
foreach ($provider->getConfiguredModels('embeddings') as $model => $label) {
$options[$id . '__' . $model] = $label;
}
}
// Send a warning message if there are no available embedding engines.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment