Skip to content
Snippets Groups Projects
Commit 9d9bb920 authored by Marcus Johansson's avatar Marcus Johansson
Browse files

Resolve #3489690 "Add an event"

parent 6c79244c
No related branches found
No related tags found
1 merge request!295Resolve #3489690 "Add an event"
Pipeline #384650 passed
Showing
with 535 additions and 29 deletions
......@@ -2,7 +2,7 @@ services:
ai.provider:
class: Drupal\ai\AiProviderPluginManager
parent: default_plugin_manager
arguments: ['@service_container', '@messenger']
arguments: ['@service_container', '@messenger', '@uuid']
ai.form_helper:
class: Drupal\ai\Service\AiProviderFormHelper
arguments: ['@ai.provider', '@path.current']
......
# Events
There are two important events that are currently available in the AI module. One that is triggered before the request is being sent and one that is triggered before the response is given from the AI Provider plugin.
There are three important events that are currently available in the AI module. One that is triggered before the request is being sent, one that is triggered before the response is given from the AI Provider plugin and one that is triggered after a streaming response is done.
These two makes it possible to change prompts, change responses, log, find bugs etc.
These three makes it possible to change prompts, change responses, log, find bugs etc.
There is also an event that is triggered when an AI provider gets uninstalled/disabled. This is good for 3rd party modules that might rely on a specific provider existing due to 3rd party provider settings.
......@@ -96,12 +96,12 @@ class CountImagesSubscriber implements EventSubscriberInterface {
}
/**
* Change IP after sending.
* Count the images.
*
* @param \Drupal\ai\Event\PreGenerateResponseEvent $event
* @param \Drupal\ai\Event\PostGenerateResponseEvent $event
* The event to count.
*/
public function countImages(PreGenerateResponseEvent $event) {
public function countImages(PostGenerateResponseEvent $event) {
// Only do AI API Explorer.
if ($event->getOperationType() == 'text-to-image') {
$pseudoCounter->countOneMore();
......@@ -112,7 +112,78 @@ class CountImagesSubscriber implements EventSubscriberInterface {
```
## Example #3: Provider Disabled.
## Example #3: Stream finished.
You have a module where you want to log the chat messages, but they happen
to be streaming responses so you can't see it on the post request event.
You do this by creating an event subscriber, something like this.
```php
<?php
namespace Drupal\ai_logging\EventSubscriber;
use Drupal\ai\Event\PostGenerateResponseEvent;
use Drupal\ai\Event\PostStreamingResponseEvent;
use Symfony\Component\EventDispatcher\EventSubscriberInterface;
/**
* Counts images.
*/
class CountImagesSubscriber implements EventSubscriberInterface {
/**
* The state in between.
*
* @var array
*/
private $state = [];
/**
* {@inheritdoc}
*
* @return array
* The post generate response event.
*/
public static function getSubscribedEvents(): array {
return [
PostGenerateResponseEvent::EVENT_NAME => 'storeInput',
PostStreamingResponseEvent::EVENT_NAME => 'storeOutput',
];
}
/**
* Store the input first.
*
* @param \Drupal\ai\Event\PostGenerateResponseEvent $event
* The event to store input.
*/
public function storeInput(PostGenerateResponseEvent $event) {
// Only do AI API Explorer.
if ($event->getOperationType() == 'chat') {
// Store with the unique id of the events happening
$this->state[$event->getRequestThreadId()]['input'] = $event->getInput();
}
}
/**
* Connect the output.
*
* @param \Drupal\ai\Event\PostStreamingResponseEvent $event
* The event to store output.
*/
public function storeOutput(PreGenerateResponseEvent $event) {
// Store the output.
$this->state[$event->getRequestThreadId()]['output'] = $event->getOutput();
}
}
```
## Example #4: Provider Disabled.
You have a third party module that is dependent on a provider called dropai.
......
......@@ -8,6 +8,7 @@ use Drupal\Core\Extension\ModuleHandlerInterface;
use Drupal\Core\Link;
use Drupal\Core\Url;
use Drupal\ai\Event\PostGenerateResponseEvent;
use Drupal\ai\Event\PostStreamingResponseEvent;
use Drupal\ai\OperationType\InputInterface;
use Symfony\Component\EventDispatcher\EventSubscriberInterface;
......@@ -39,6 +40,13 @@ class LogPostRequestEventSubscriber implements EventSubscriberInterface {
*/
protected $moduleHandler;
/**
* UUID to log for streaming.
*
* @var array
*/
protected $streamingUuids = [];
/**
* Constructor.
*/
......@@ -57,6 +65,7 @@ class LogPostRequestEventSubscriber implements EventSubscriberInterface {
public static function getSubscribedEvents(): array {
return [
PostGenerateResponseEvent::EVENT_NAME => 'logPostRequest',
PostStreamingResponseEvent::EVENT_NAME => 'logPostStream',
];
}
......@@ -85,6 +94,33 @@ class LogPostRequestEventSubscriber implements EventSubscriberInterface {
$log->set('output_text', json_encode($event->getOutput()->getRawOutput()));
}
$log->save();
// We store the connection for logging the streamed response.
if ($this->aiSettings->get('prompt_logging_output')) {
if ($event->getOutput()->getNormalized() instanceof \IteratorAggregate) {
$this->streamingUuids[$event->getRequestThreadId()] = $log->id();
}
}
}
}
/**
* If the log was a streaming object, we need to update with the response.
*
* @param \Drupal\ai\Event\PostStreamingResponseEvent $event
* The event to log.
*/
public function logPostStream(PostStreamingResponseEvent $event) {
// If response logging is enabled, add the streamed response.
if ($this->aiSettings->get('prompt_logging_output') && isset($this->streamingUuids[$event->getRequestThreadId()])) {
// Load to update.
$storage = $this->entityTypeManager->getStorage('ai_log');
/** @var \Drupal\ai_logging\Entity\AiLog $log */
$log = $storage->load($this->streamingUuids[$event->getRequestThreadId()]);
if ($log) {
$log->set('output_text', json_encode($event->getOutput()));
$log->save();
}
}
}
......
......@@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Drupal\ai;
use Drupal\Component\Uuid\UuidInterface;
use Drupal\Core\Cache\CacheBackendInterface;
use Drupal\Core\Extension\ModuleHandlerInterface;
use Drupal\Core\Messenger\MessengerInterface;
......@@ -65,6 +66,13 @@ final class AiProviderPluginManager extends DefaultPluginManager {
*/
protected $messenger;
/**
* The UUID service.
*
* @var \Drupal\Component\Uuid\UuidInterface
*/
protected $uuid;
/**
* Constructs the object.
*/
......@@ -74,6 +82,7 @@ final class AiProviderPluginManager extends DefaultPluginManager {
ModuleHandlerInterface $module_handler,
ContainerInterface $container,
MessengerInterface $messenger,
UuidInterface $uuid,
) {
parent::__construct('Plugin/AiProvider', $namespaces, $module_handler, AiProviderInterface::class, AiProvider::class);
$this->alterInfo('ai_provider_info');
......@@ -84,6 +93,7 @@ final class AiProviderPluginManager extends DefaultPluginManager {
$this->moduleHandler = $module_handler;
$this->configFactory = $container->get('config.factory');
$this->messenger = $messenger;
$this->uuid = $uuid;
}
/**
......@@ -102,7 +112,7 @@ final class AiProviderPluginManager extends DefaultPluginManager {
*/
public function createInstance($plugin_id, array $configuration = []): ProviderProxy {
$plugin = parent::createInstance($plugin_id, $configuration);
return new ProviderProxy($plugin, $this->eventDispatcher, $this->loggerFactory);
return new ProviderProxy($plugin, $this->eventDispatcher, $this->loggerFactory, $this->uuid);
}
/**
......
......@@ -12,6 +12,13 @@ class PostGenerateResponseEvent extends Event {
// The event name.
const EVENT_NAME = 'ai.post_generate_response';
/**
* The request thread id.
*
* @var string
*/
protected $requestThreadId;
/**
* The provider to process.
*
......@@ -78,6 +85,8 @@ class PostGenerateResponseEvent extends Event {
/**
* Constructs the object.
*
* @param string $request_thread_id
* The unique request thread id.
* @param string $provider_id
* The provider to process.
* @param string $operation_type
......@@ -97,7 +106,8 @@ class PostGenerateResponseEvent extends Event {
* @param array $metadata
* The metadata to store for the request.
*/
public function __construct(string $provider_id, string $operation_type, array $configuration, mixed $input, string $model_id, mixed $output, array $tags = [], array $debug_data = [], array $metadata = []) {
public function __construct(string $request_thread_id, string $provider_id, string $operation_type, array $configuration, mixed $input, string $model_id, mixed $output, array $tags = [], array $debug_data = [], array $metadata = []) {
$this->requestThreadId = $request_thread_id;
$this->providerId = $provider_id;
$this->configuration = $configuration;
$this->operationType = $operation_type;
......@@ -109,6 +119,16 @@ class PostGenerateResponseEvent extends Event {
$this->metadata = $metadata;
}
/**
* Gets the request thread id.
*
* @return string
* The request thread id.
*/
public function getRequestThreadId() {
return $this->requestThreadId;
}
/**
* Gets the provider.
*
......@@ -179,6 +199,18 @@ class PostGenerateResponseEvent extends Event {
return $this->tags;
}
/**
* Allow to set a new tag.
*
* @param string $tag
* The tag.
* @param mixed $value
* The value.
*/
public function setTag(string $tag, mixed $value) {
$this->tags[$tag] = $value;
}
/**
* Gets the debug data.
*
......
<?php
namespace Drupal\ai\Event;
use Drupal\Component\EventDispatcher\Event;
/**
* For collecting the results post streaming.
*
* This event should be used in conjunction with the PostGenerateResponseEvent
* using the request thread id to connect the two events. There is no
* manipulation of the data in this event, it is just for collecting the final
* results.
*/
class PostStreamingResponseEvent extends Event {
// The event name.
const EVENT_NAME = 'ai.post_streaming_response';
/**
* The request thread id id.
*
* @var string
*/
protected $requestThreadId;
/**
* The output for the request.
*
* @var mixed
*/
protected $role;
/**
* The output for the request.
*
* @var mixed
* The output for the request.
*/
protected $output;
/**
* The metadata to store for the request.
*
* @var array
*/
protected array $metadata;
/**
* Constructs the object.
*
* @param string $request_thread_id
* The unique request thread id.
* @param mixed $output
* The output for the request.
* @param array $metadata
* The metadata to store for the request.
*/
public function __construct(string $request_thread_id, $output, array $metadata = []) {
$this->requestThreadId = $request_thread_id;
$this->output = $output;
$this->metadata = $metadata;
}
/**
* Gets the request thread id.
*
* @return string
* The request thread id.
*/
public function getRequestThreadId() {
return $this->requestThreadId;
}
/**
* Gets the output.
*
* @return mixed
* The output.
*/
public function getOutput() {
return $this->output;
}
/**
* Get all the metadata.
*
* @return array
* All the metadata.
*/
public function getAllMetadata(): array {
return $this->metadata;
}
/**
* Set all metadata replacing existing contents.
*
* @param array $metadata
* All the metadata.
*/
public function setAllMetadata(array $metadata): void {
$this->metadata = $metadata;
}
/**
* Get specific metadata by key.
*
* @param string $metadata_key
* The key of the metadata to return.
*
* @return mixed
* The metadata for the provided key.
*/
public function getMetadata(string $metadata_key): mixed {
return $this->metadata[$metadata_key];
}
/**
* Add to the metadata by key.
*
* @param string $key
* The key.
* @param mixed $value
* The value.
*/
public function setMetadata(string $key, mixed $value): void {
$this->metadata[$key] = $value;
}
}
......@@ -12,6 +12,13 @@ class PreGenerateResponseEvent extends Event {
// The event name.
const EVENT_NAME = 'ai.pre_generate_response';
/**
* The request thread id.
*
* @var string
*/
protected $requestThreadId;
/**
* The provider to process.
*
......@@ -85,6 +92,8 @@ class PreGenerateResponseEvent extends Event {
/**
* Constructs the object.
*
* @param string $request_thread_id
* The unique request thread id.
* @param string $provider_id
* The provider to process.
* @param string $operation_type
......@@ -102,7 +111,8 @@ class PreGenerateResponseEvent extends Event {
* @param array $metadata
* The metadata to store for the request.
*/
public function __construct(string $provider_id, string $operation_type, array $configuration, mixed $input, string $model_id, array $tags = [], array $debug_data = [], array $metadata = []) {
public function __construct(string $request_thread_id, string $provider_id, string $operation_type, array $configuration, mixed $input, string $model_id, array $tags = [], array $debug_data = [], array $metadata = []) {
$this->requestThreadId = $request_thread_id;
$this->providerId = $provider_id;
$this->configuration = $configuration;
$this->operationType = $operation_type;
......@@ -113,6 +123,16 @@ class PreGenerateResponseEvent extends Event {
$this->metadata = $metadata;
}
/**
* Gets the request thread id.
*
* @return string
* The request thread id.
*/
public function getRequestThreadId() {
return $this->requestThreadId;
}
/**
* Gets the provider.
*
......
......@@ -2,11 +2,16 @@
namespace Drupal\ai\OperationType\Chat;
use Drupal\ai\Event\PostStreamingResponseEvent;
use Drupal\ai\Traits\OperationType\EventDispatcherTrait;
/**
* Streamed chat message iterator interface.
*/
abstract class StreamedChatMessageIterator implements StreamedChatMessageIteratorInterface {
use EventDispatcherTrait;
/**
* The iterator.
*
......@@ -14,6 +19,21 @@ abstract class StreamedChatMessageIterator implements StreamedChatMessageIterato
*/
protected $iterator;
/**
* The messages.
*
* @var array
* The stream chat messages.
*/
protected $messages = [];
/**
* The request thread id.
*
* @var string
*/
protected $requestThreadId;
/**
* Constructor.
*/
......@@ -21,4 +41,59 @@ abstract class StreamedChatMessageIterator implements StreamedChatMessageIterato
$this->iterator = $iterator;
}
/**
* Trigger the event on streaming finished.
*/
public function triggerEvent(): void {
// Create a ChatMessage out of it all.
$role = '';
$message_text = '';
foreach ($this->messages as $message) {
if (!empty($message->getRole()) && empty($role)) {
$role = $message->getRole();
}
if (!empty($message->getText())) {
$message_text .= $message->getText();
}
}
$message = [
'role' => $role,
'message' => $message_text,
];
// Dispatch the event.
$event = new PostStreamingResponseEvent($this->requestThreadId, $message, []);
$this->getEventDispatcher()->dispatch($event, PostStreamingResponseEvent::EVENT_NAME);
}
/**
* {@inheritdoc}
*/
public function setRequestThreadId(string $request_thread_id): void {
$this->requestThreadId = $request_thread_id;
}
/**
* {@inheritdoc}
*/
public function getRequestThreadId(): string {
return $this->requestThreadId;
}
/**
* {@inheritdoc}
*/
public function createStreamedChatMessage(string $role, string $message, array $metadata): StreamedChatMessageInterface {
$message = new StreamedChatMessage($role, $message, $metadata);
$this->messages[] = $message;
return $message;
}
/**
* {@inheritdoc}
*/
public function getStreamChatMessages(): array {
return $this->messages;
}
}
......@@ -9,4 +9,48 @@ interface StreamedChatMessageIteratorInterface extends \IteratorAggregate {
public function __construct(\IteratorAggregate $iterator);
/**
* Set an request thread id.
*
* @param string $request_thread_id
* The request thread id.
*/
public function setRequestThreadId(string $request_thread_id): void;
/**
* Get the request thread id.
*
* @return string
* The request thread id.
*/
public function getRequestThreadId(): string;
/**
* Sets on stream chat message.
*
* @param string $role
* The role.
* @param string $message
* The message.
* @param array $metadata
* The metadata.
*
* @return \Drupal\ai\OperationType\Chat\StreamedChatMessageInterface
* The streamed chat message.
*/
public function createStreamedChatMessage(string $role, string $message, array $metadata): StreamedChatMessageInterface;
/**
* Gets the stream chat messages.
*
* @return array
* The stream chat messages.
*/
public function getStreamChatMessages(): array;
/**
* Trigger the event on streaming finished.
*/
public function triggerEvent(): void;
}
......@@ -2,6 +2,7 @@
namespace Drupal\ai\Plugin;
use Drupal\Component\Uuid\UuidInterface;
use Drupal\Core\Logger\LoggerChannelFactoryInterface;
use Drupal\ai\Base\AiProviderClientBase;
use Drupal\ai\Event\PostGenerateResponseEvent;
......@@ -44,6 +45,13 @@ class ProviderProxy {
*/
protected $loggerFactory;
/**
* The UUID service.
*
* @var \Drupal\Component\Uuid\UuidInterface
*/
protected $uuid;
/**
* PluginLoggingProxy constructor.
*
......@@ -53,11 +61,14 @@ class ProviderProxy {
* The event dispatcher.
* @param \Drupal\Core\Logger\LoggerChannelFactoryInterface $logger_factory
* The logger factory.
* @param \Drupal\Component\Uuid\UuidInterface $uuid
* The UUID service.
*/
public function __construct(AiProviderClientBase $plugin, EventDispatcherInterface $event_dispatcher, LoggerChannelFactoryInterface $logger_factory) {
public function __construct(AiProviderClientBase $plugin, EventDispatcherInterface $event_dispatcher, LoggerChannelFactoryInterface $logger_factory, UuidInterface $uuid) {
$this->plugin = $plugin;
$this->eventDispatcher = $event_dispatcher;
$this->loggerFactory = $logger_factory;
$this->uuid = $uuid;
}
/**
......@@ -138,8 +149,11 @@ class ProviderProxy {
$this->plugin->setTag($tag);
}
// Create a unique event id.
$event_id = $this->uuid->generate();
// Invoke the pre generate response event.
$pre_generate_event = new PreGenerateResponseEvent($this->plugin->getPluginId(), $operation_type, $this->plugin->configuration, $arguments[0], $arguments[1], $this->plugin->getTags(), $this->plugin->getDebugData());
$pre_generate_event = new PreGenerateResponseEvent($event_id, $this->plugin->getPluginId(), $operation_type, $this->plugin->configuration, $arguments[0], $arguments[1], $this->plugin->getTags(), $this->plugin->getDebugData());
$this->eventDispatcher->dispatch($pre_generate_event, PreGenerateResponseEvent::EVENT_NAME);
// Get the possible new auth, configuration and input from the event.
......@@ -196,11 +210,16 @@ class ProviderProxy {
}
// Invoke the post generate response event.
$post_generate_event = new PostGenerateResponseEvent($this->plugin->getPluginId(), $operation_type, $this->plugin->configuration, $arguments[0], $arguments[1], $response, $this->plugin->getTags(), $this->plugin->getDebugData(), $pre_generate_event->getAllMetadata());
$post_generate_event = new PostGenerateResponseEvent($event_id, $this->plugin->getPluginId(), $operation_type, $this->plugin->configuration, $arguments[0], $arguments[1], $response, $this->plugin->getTags(), $this->plugin->getDebugData(), $pre_generate_event->getAllMetadata());
$this->eventDispatcher->dispatch($post_generate_event, PostGenerateResponseEvent::EVENT_NAME);
// Get a potential new response from the event.
$response = $post_generate_event->getOutput();
// Since we need to attach events on streaming responses as well.
if ($response->getNormalized() instanceof \IteratorAggregate) {
$response->getNormalized()->setRequestThreadId($event_id);
}
// Return the response.
return $response;
}
......
<?php
namespace Drupal\ai\Traits\OperationType;
use Symfony\Component\EventDispatcher\EventDispatcherInterface;
/**
* Event dispatcher trait for operation types.
*
* @package Drupal\ai\Traits\OperationType
*/
trait EventDispatcherTrait {
/**
* {@inheritdoc}
*/
public function getEventDispatcher(): EventDispatcherInterface {
return \Drupal::service('event_dispatcher');
}
}
......@@ -3,26 +3,13 @@
namespace Drupal\Tests\ai\Mock;
use Drupal\ai\OperationType\Chat\StreamedChatMessage;
use Drupal\ai\OperationType\Chat\StreamedChatMessageIterator;
use Drupal\ai\OperationType\Chat\StreamedChatMessageIteratorInterface;
/**
* Mock chat iterator for testing.
*/
class MockStreamedChatIterator implements StreamedChatMessageIteratorInterface {
/**
* The iterator.
*
* @var \IteratorAggregate
*/
private $iterator;
/**
* {@inheritdoc}
*/
public function __construct(\IteratorAggregate $iterator) {
$this->iterator = $iterator;
}
class MockStreamedChatIterator extends StreamedChatMessageIterator implements StreamedChatMessageIteratorInterface {
/**
* Get the iterator.
......
......@@ -114,6 +114,14 @@ class PostGenerateResponseEventTest extends TestCase {
$this->assertEquals('It is not!', $event->getOutput());
}
/**
* Test get event id.
*/
public function testEventId(): void {
$event = $this->getEvent();
$this->assertEquals('unique_id', $event->getRequestThreadId());
}
/**
* Helper function to get the events.
*
......@@ -121,7 +129,7 @@ class PostGenerateResponseEventTest extends TestCase {
* The event.
*/
public function getEvent(): PostGenerateResponseEvent {
return new PostGenerateResponseEvent('test', 'chat', [
return new PostGenerateResponseEvent('unique_id', 'test', 'chat', [
'test' => 'testing',
],
'This is a test',
......
<?php
namespace Drupal\Tests\ai\Unit\Event;
use Drupal\ai\Event\PostStreamingResponseEvent;
use PHPUnit\Framework\TestCase;
/**
* Tests that the event function works.
*
* @group ai
* @covers \Drupal\ai\Event\PostStreamingResponseEvent
*/
class PostStreamingResponseEventTest extends TestCase {
/**
* Test get event id.
*/
public function testEventId(): void {
$event = $this->getEvent();
$this->assertEquals('unique_id', $event->getRequestThreadId());
}
/**
* Test get response.
*/
public function testResponse(): void {
$event = $this->getEvent();
$this->assertEquals('test', $event->getOutput());
}
/**
* Helper function to get the events.
*
* @return \Drupal\ai\Event\PostStreamingResponseEvent|\PHPUnit\Framework\MockObject\MockObject
* The event.
*/
public function getEvent(): PostStreamingResponseEvent {
return new PostStreamingResponseEvent('unique_id', 'test', [
'test' => 'test',
]);
}
}
......@@ -97,6 +97,14 @@ class PreGenerateResponseEventTest extends TestCase {
], $event->getConfiguration());
}
/**
* Test get event id.
*/
public function testEventId(): void {
$event = $this->getEvent();
$this->assertEquals('unique_id', $event->getRequestThreadId());
}
/**
* Helper function to get the events.
*
......@@ -105,6 +113,7 @@ class PreGenerateResponseEventTest extends TestCase {
*/
public function getEvent(): PreGenerateResponseEvent {
return new PreGenerateResponseEvent(
'unique_id',
'test',
'chat',
[
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment