Skip to content
Snippets Groups Projects

Enhance AzureProvider to support tools and structured JSON responses

Files
4
@@ -2,11 +2,6 @@
namespace Drupal\ai_provider_azure\Plugin\AiProvider;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\File\FileExists;
use Drupal\Core\Plugin\ContainerFactoryPluginInterface;
use Drupal\Core\StringTranslation\StringTranslationTrait;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use Drupal\ai\Attribute\AiProvider;
use Drupal\ai\Base\AiProviderClientBase;
use Drupal\ai\Enum\AiProviderCapability;
@@ -19,6 +14,7 @@ use Drupal\ai\OperationType\Chat\ChatInput;
use Drupal\ai\OperationType\Chat\ChatInterface;
use Drupal\ai\OperationType\Chat\ChatMessage;
use Drupal\ai\OperationType\Chat\ChatOutput;
use Drupal\ai\OperationType\Chat\Tools\ToolsFunctionOutput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInput;
use Drupal\ai\OperationType\Embeddings\EmbeddingsInterface;
use Drupal\ai\OperationType\Embeddings\EmbeddingsOutput;
@@ -36,6 +32,12 @@ use Drupal\ai\OperationType\TextToSpeech\TextToSpeechOutput;
use Drupal\ai\Traits\OperationType\ChatTrait;
use Drupal\ai_provider_azure\AzureChatMessageIterator;
use Drupal\ai_provider_azure\Client\LightweightProviderClient;
use Drupal\Component\Serialization\Json;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\File\FileExists;
use Drupal\Core\Plugin\ContainerFactoryPluginInterface;
use Drupal\Core\StringTranslation\StringTranslationTrait;
use Drupal\Core\StringTranslation\TranslatableMarkup;
use OpenAI\Client;
use Symfony\Component\DependencyInjection\ContainerInterface;
use Symfony\Component\Yaml\Yaml;
@@ -328,22 +330,59 @@ class AzureProvider extends AiProviderClientBase implements
];
}
}
$chat_input[] = [
$new_message = [
'role' => $message->getRole(),
'content' => $content,
];
// If its a tools response.
if ($message->getToolsId()) {
$new_message['tool_call_id'] = $message->getToolsId();
}
// If we want the results from some older tools call.
if ($message->getTools()) {
$new_message['tool_calls'] = $message->getRenderedTools();
}
$chat_input[] = $new_message;
}
}
$payload = [
'messages' => $chat_input,
];
// If we want to add tools to the input.
if (method_exists($input, 'getChatTools') && $input->getChatTools()) {
$payload['tools'] = $input->getChatTools()->renderToolsArray();
foreach ($payload['tools'] as $key => $tool) {
$payload['tools'][$key]['function']['strict'] = FALSE;
}
}
// Check for structured json schemas.
if (method_exists($input, 'getChatStructuredJsonSchema') && $input->getChatStructuredJsonSchema()) {
$payload['response_format'] = [
'type' => 'json_schema',
'json_schema' => $input->getChatStructuredJsonSchema(),
];
}
try {
if ($this->streamed) {
$response = $this->client->chat()->createStreamed($payload);
$message = new AzureChatMessageIterator($response);
}
else {
$response = $this->client->chat()->create($payload)->toArray();
// If tools are generated.
$tools = [];
if (!empty($response['choices'][0]['message']['tool_calls'])) {
foreach ($response['choices'][0]['message']['tool_calls'] as $tool) {
$arguments = Json::decode($tool['function']['arguments']);
$tools[] = new ToolsFunctionOutput($input->getChatTools()->getFunctionByName($tool['function']['name']), $tool['id'], $arguments);
}
}
$message = '';
}
}
catch (\Exception $e) {
@@ -360,7 +399,6 @@ class AzureProvider extends AiProviderClientBase implements
}
}
$message = '';
if ($this->streamed) {
$message = new AzureChatMessageIterator($response);
}
@@ -368,7 +406,10 @@ class AzureProvider extends AiProviderClientBase implements
$consumer = $this->messageConsumers[$info['custom_consumer']] ?? NULL;
// If no consumer, we consume as usual.
if (!$consumer) {
$message = new ChatMessage($response['choices'][0]['message']['role'], $response['choices'][0]['message']['content']);
$message = new ChatMessage($response['choices'][0]['message']['role'], $response['choices'][0]['message']['content'] ?? '', []);
if (!empty($tools)) {
$message->setTools($tools);
}
}
else {
// Otherwise check if its multiple or not.
Loading