Skip to content
Snippets Groups Projects
Commit 45cb1967 authored by Marcus Johansson's avatar Marcus Johansson
Browse files

Issue #3489690: Add an event that triggers on chat iterator consumption

parent c5460672
No related branches found
No related tags found
1 merge request!295Resolve #3489690 "Add an event"
# Events
There are two important events that are currently available in the AI module. One that is triggered before the request is being sent and one that is triggered before the response is given from the AI Provider plugin.
There are three important events that are currently available in the AI module. One that is triggered before the request is being sent, one that is triggered before the response is given from the AI Provider plugin and one that is triggered after a streaming response is done.
These two makes it possible to change prompts, change responses, log, find bugs etc.
These three makes it possible to change prompts, change responses, log, find bugs etc.
## Example #1: Pre request.
......@@ -62,7 +62,7 @@ class IpCheckSubscriber implements EventSubscriberInterface {
```
## Example #1: Post request.
## Example #2: Post request.
You have a module where you want to log how many images you created in total.
......@@ -94,12 +94,12 @@ class CountImagesSubscriber implements EventSubscriberInterface {
}
/**
* Change IP after sending.
* Count the images.
*
* @param \Drupal\ai\Event\PreGenerateResponseEvent $event
* @param \Drupal\ai\Event\PostGenerateResponseEvent $event
* The event to count.
*/
public function countImages(PreGenerateResponseEvent $event) {
public function countImages(PostGenerateResponseEvent $event) {
// Only do AI API Explorer.
if ($event->getOperationType() == 'text-to-image') {
$pseudoCounter->countOneMore();
......@@ -109,3 +109,74 @@ class CountImagesSubscriber implements EventSubscriberInterface {
}
```
## Example #3: Stream finished.
You have a module where you want to log the chat messages, but they happen
to be streaming responses so you can't see it on the post request event.
You do this by creating an event subscriber, something like this.
```php
<?php
namespace Drupal\ai_logging\EventSubscriber;
use Drupal\ai\Event\PostGenerateResponseEvent;
use Drupal\ai\Event\PostStreamingResponseEvent;
use Symfony\Component\EventDispatcher\EventSubscriberInterface;
/**
* Counts images.
*/
class CountImagesSubscriber implements EventSubscriberInterface {
/**
* The state in between.
*
* @var array
*/
private $state = [];
/**
* {@inheritdoc}
*
* @return array
* The post generate response event.
*/
public static function getSubscribedEvents(): array {
return [
PostGenerateResponseEvent::EVENT_NAME => 'storeInput',
PostStreamingResponseEvent::EVENT_NAME => 'storeOutput',
];
}
/**
* Store the input first.
*
* @param \Drupal\ai\Event\PostGenerateResponseEvent $event
* The event to store input.
*/
public function storeInput(PostGenerateResponseEvent $event) {
// Only do AI API Explorer.
if ($event->getOperationType() == 'chat') {
// Store with the unique id of the events happening
$this->state[$event->getRequestThreadId()]['input'] = $event->getInput();
}
}
/**
* Connect the output.
*
* @param \Drupal\ai\Event\PostStreamingResponseEvent $event
* The event to store output.
*/
public function stoureOutput(PreGenerateResponseEvent $event) {
// Store the output.
$this->state[$event->getRequestThreadId()]['output'] = $event->getOutput();
}
}
```
......@@ -97,7 +97,7 @@ class LogPostRequestEventSubscriber implements EventSubscriberInterface {
// We store the connection for logging the streamed response.
if ($this->aiSettings->get('prompt_logging_output')) {
if ($event->getOutput()->getNormalized() instanceof \IteratorAggregate) {
$this->streamingUuids[$event->getEventId()] = $log->id();
$this->streamingUuids[$event->getRequestThreadId()] = $log->id();
}
}
}
......@@ -111,11 +111,11 @@ class LogPostRequestEventSubscriber implements EventSubscriberInterface {
*/
public function logPostStream(PostStreamingResponseEvent $event) {
// If response logging is enabled, add the streamed response.
if ($this->aiSettings->get('prompt_logging_output') && isset($this->streamingUuids[$event->getEventId()])) {
if ($this->aiSettings->get('prompt_logging_output') && isset($this->streamingUuids[$event->getRequestThreadId()])) {
// Load to update.
$storage = $this->entityTypeManager->getStorage('ai_log');
/** @var \Drupal\ai_logging\Entity\AiLog $log */
$log = $storage->load($this->streamingUuids[$event->getEventId()]);
$log = $storage->load($this->streamingUuids[$event->getRequestThreadId()]);
if ($log) {
$log->set('output_text', json_encode($event->getOutput()));
......
......@@ -13,11 +13,11 @@ class PostGenerateResponseEvent extends Event {
const EVENT_NAME = 'ai.post_generate_response';
/**
* The event id.
* The request thread id.
*
* @var string
*/
protected $eventId;
protected $requestThreadId;
/**
* The provider to process.
......@@ -85,8 +85,8 @@ class PostGenerateResponseEvent extends Event {
/**
* Constructs the object.
*
* @param string $event_id
* The unique event id.
* @param string $request_thread_id
* The unique request thread id.
* @param string $provider_id
* The provider to process.
* @param string $operation_type
......@@ -106,8 +106,8 @@ class PostGenerateResponseEvent extends Event {
* @param array $metadata
* The metadata to store for the request.
*/
public function __construct(string $event_id, string $provider_id, string $operation_type, array $configuration, mixed $input, string $model_id, mixed $output, array $tags = [], array $debug_data = [], array $metadata = []) {
$this->eventId = $event_id;
public function __construct(string $request_thread_id, string $provider_id, string $operation_type, array $configuration, mixed $input, string $model_id, mixed $output, array $tags = [], array $debug_data = [], array $metadata = []) {
$this->requestThreadId = $request_thread_id;
$this->providerId = $provider_id;
$this->configuration = $configuration;
$this->operationType = $operation_type;
......@@ -120,13 +120,13 @@ class PostGenerateResponseEvent extends Event {
}
/**
* Gets the event id.
* Gets the request thread id.
*
* @return string
* The event id.
* The request thread id.
*/
public function getEventId() {
return $this->eventId;
public function getRequestThreadId() {
return $this->requestThreadId;
}
/**
......
......@@ -8,8 +8,9 @@ use Drupal\Component\EventDispatcher\Event;
* For collecting the results post streaming.
*
* This event should be used in conjunction with the PostGenerateResponseEvent
* using the event id to connect the two events. There is no manipulation of
* the data in this event, it is just for collecting the final results.
* using the request thread id to connect the two events. There is no
* manipulation of the data in this event, it is just for collecting the final
* results.
*/
class PostStreamingResponseEvent extends Event {
......@@ -17,11 +18,11 @@ class PostStreamingResponseEvent extends Event {
const EVENT_NAME = 'ai.post_streaming_response';
/**
* The event id.
* The request thread id id.
*
* @var string
*/
protected $eventId;
protected $requestThreadId;
/**
* The output for the request.
......@@ -48,27 +49,27 @@ class PostStreamingResponseEvent extends Event {
/**
* Constructs the object.
*
* @param string $event_id
* The unique event id.
* @param string $request_thread_id
* The unique request thread id.
* @param mixed $output
* The output for the request.
* @param array $metadata
* The metadata to store for the request.
*/
public function __construct(string $event_id, $output, array $metadata = []) {
$this->eventId = $event_id;
public function __construct(string $request_thread_id, $output, array $metadata = []) {
$this->requestThreadId = $request_thread_id;
$this->output = $output;
$this->metadata = $metadata;
}
/**
* Gets the event id.
* Gets the request thread id.
*
* @return string
* The event id.
* The request thread id.
*/
public function getEventId() {
return $this->eventId;
public function getRequestThreadId() {
return $this->requestThreadId;
}
/**
......
......@@ -13,11 +13,11 @@ class PreGenerateResponseEvent extends Event {
const EVENT_NAME = 'ai.pre_generate_response';
/**
* The event id.
* The request thread id.
*
* @var string
*/
protected $eventId;
protected $requestThreadId;
/**
* The provider to process.
......@@ -92,8 +92,8 @@ class PreGenerateResponseEvent extends Event {
/**
* Constructs the object.
*
* @param string $event_id
* The unique event id.
* @param string $request_thread_id
* The unique request thread id.
* @param string $provider_id
* The provider to process.
* @param string $operation_type
......@@ -111,8 +111,8 @@ class PreGenerateResponseEvent extends Event {
* @param array $metadata
* The metadata to store for the request.
*/
public function __construct(string $event_id, string $provider_id, string $operation_type, array $configuration, mixed $input, string $model_id, array $tags = [], array $debug_data = [], array $metadata = []) {
$this->eventId = $event_id;
public function __construct(string $request_thread_id, string $provider_id, string $operation_type, array $configuration, mixed $input, string $model_id, array $tags = [], array $debug_data = [], array $metadata = []) {
$this->requestThreadId = $request_thread_id;
$this->providerId = $provider_id;
$this->configuration = $configuration;
$this->operationType = $operation_type;
......@@ -124,13 +124,13 @@ class PreGenerateResponseEvent extends Event {
}
/**
* Gets the event id.
* Gets the request thread id.
*
* @return string
* The event id.
* The request thread id.
*/
public function getEventId() {
return $this->eventId;
public function getRequestThreadId() {
return $this->requestThreadId;
}
/**
......
......@@ -28,11 +28,11 @@ abstract class StreamedChatMessageIterator implements StreamedChatMessageIterato
protected $messages = [];
/**
* The event id.
* The request thread id.
*
* @var string
*/
protected $eventId;
protected $requestThreadId;
/**
* Constructor.
......@@ -62,22 +62,22 @@ abstract class StreamedChatMessageIterator implements StreamedChatMessageIterato
];
// Dispatch the event.
$event = new PostStreamingResponseEvent($this->eventId, $message, []);
$event = new PostStreamingResponseEvent($this->requestThreadId, $message, []);
$this->getEventDispatcher()->dispatch($event, PostStreamingResponseEvent::EVENT_NAME);
}
/**
* {@inheritdoc}
*/
public function setEventId(string $eventId): void {
$this->eventId = $eventId;
public function setRequestThreadId(string $request_thread_id): void {
$this->requestThreadId = $request_thread_id;
}
/**
* {@inheritdoc}
*/
public function getEventId(): string {
return $this->eventId;
public function getRequestThreadId(): string {
return $this->requestThreadId;
}
/**
......
......@@ -10,20 +10,20 @@ interface StreamedChatMessageIteratorInterface extends \IteratorAggregate {
public function __construct(\IteratorAggregate $iterator);
/**
* Set an event id.
* Set an request thread id.
*
* @param string $eventId
* The event id.
* @param string $request_thread_id
* The request thread id.
*/
public function setEventId(string $eventId): void;
public function setRequestThreadId(string $request_thread_id): void;
/**
* Get the event id.
* Get the request thread id.
*
* @return string
* The event id.
* The request thread id.
*/
public function getEventId(): string;
public function getRequestThreadId(): string;
/**
* Sets on stream chat message.
......
......@@ -217,7 +217,7 @@ class ProviderProxy {
// Since we need to attach events on streaming responses as well.
if ($response->getNormalized() instanceof \IteratorAggregate) {
$response->getNormalized()->setEventId($event_id);
$response->getNormalized()->setRequestThreadId($event_id);
}
// Return the response.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment