From 92b71aefc08a6efb443eddb56dc9d8d31a307664 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 22 May 2026 18:23:29 +0200 Subject: [PATCH] Fabric IQ and Data Gen samples --- sdk/ai/azure-ai-agents/assets.json | 2 +- .../azure/ai/agents/tools/FabricIQAsync.java | 9 +- .../azure/ai/agents/tools/FabricIQSync.java | 9 +- .../com/azure/ai/agents/ClientTestBase.java | 4 +- .../ai/agents/tools/FabricIQSamplesTests.java | 110 ++++++- sdk/ai/azure-ai-projects/assets.json | 2 +- ...DataGenerationJobWithEvaluationSample.java | 308 ++++++++++++++++++ .../com/azure/ai/projects/ClientTestBase.java | 4 +- .../com/azure/ai/projects/SamplesTests.java | 120 +++++++ 9 files changed, 551 insertions(+), 17 deletions(-) create mode 100644 sdk/ai/azure-ai-projects/src/samples/java/com/azure/ai/projects/DataGenerationJobWithEvaluationSample.java diff --git a/sdk/ai/azure-ai-agents/assets.json b/sdk/ai/azure-ai-agents/assets.json index bf21adf5156b..bbeb80ce0ad6 100644 --- a/sdk/ai/azure-ai-agents/assets.json +++ b/sdk/ai/azure-ai-agents/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "java", "TagPrefix": "java/ai/azure-ai-agents", - "Tag": "java/ai/azure-ai-agents_5238437e7d" + "Tag": "java/ai/azure-ai-agents_767bd42da6" } diff --git a/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQAsync.java b/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQAsync.java index 5b4c67bc5cb2..e22571576b8f 100644 --- a/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQAsync.java +++ b/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQAsync.java @@ -28,6 +28,7 @@ *
<li>FOUNDRY_PROJECT_ENDPOINT - The Azure AI Project endpoint.</li>
 * <li>FOUNDRY_MODEL_NAME - The model deployment name.</li>
 * <li>FABRIC_IQ_PROJECT_CONNECTION_ID - The FabricIQ connection ID.</li>
+ * <li>FABRIC_IQ_USER_INPUT - Optional. The natural-language question to send to the agent.</li>
  • * */ public class FabricIQAsync { @@ -35,6 +36,8 @@ public static void main(String[] args) { String endpoint = Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT"); String model = Configuration.getGlobalConfiguration().get("FOUNDRY_MODEL_NAME"); String fabricIqConnectionId = Configuration.getGlobalConfiguration().get("FABRIC_IQ_PROJECT_CONNECTION_ID"); + String userInput = Configuration.getGlobalConfiguration().get("FABRIC_IQ_USER_INPUT", + "Use FabricIQ to summarize the available enterprise context."); AgentsClientBuilder builder = new AgentsClientBuilder() .credential(new DefaultAzureCredentialBuilder().build()) @@ -51,10 +54,10 @@ public static void main(String[] args) { .setDescription("Use FabricIQ to answer questions grounded in enterprise data."); PromptAgentDefinition agentDefinition = new PromptAgentDefinition(model) - .setInstructions("You are a data assistant that can use FabricIQ for grounded enterprise answers.") + .setInstructions("Use the available Fabric IQ tools to answer questions and perform tasks.") .setTools(Collections.singletonList(fabricIqTool)); - agentsAsyncClient.createAgentVersion("fabric-iq-async-agent", agentDefinition) + agentsAsyncClient.createAgentVersion("MyAgent", agentDefinition) .flatMap(agent -> { agentRef.set(agent); System.out.printf("Agent created: %s (version %s)%n", agent.getName(), agent.getVersion()); @@ -65,7 +68,7 @@ public static void main(String[] args) { return responsesAsyncClient.createAzureResponse( new AzureCreateResponseOptions().setAgentReference(agentReference), ResponseCreateParams.builder() - .input("Use FabricIQ to summarize the available enterprise context.")); + .input(userInput)); }) .doOnNext(response -> System.out.println("Response: " + response.output())) .then(Mono.defer(() -> { diff --git a/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQSync.java b/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQSync.java index 
34d7272b8c06..e0b397f4fab9 100644 --- a/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQSync.java +++ b/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQSync.java @@ -26,6 +26,7 @@ *
<li>FOUNDRY_PROJECT_ENDPOINT - The Azure AI Project endpoint.</li>
 * <li>FOUNDRY_MODEL_NAME - The model deployment name.</li>
 * <li>FABRIC_IQ_PROJECT_CONNECTION_ID - The FabricIQ connection ID.</li>
+ * <li>FABRIC_IQ_USER_INPUT - Optional. The natural-language question to send to the agent.</li>
  • * */ public class FabricIQSync { @@ -33,6 +34,8 @@ public static void main(String[] args) { String endpoint = Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT"); String model = Configuration.getGlobalConfiguration().get("FOUNDRY_MODEL_NAME"); String fabricIqConnectionId = Configuration.getGlobalConfiguration().get("FABRIC_IQ_PROJECT_CONNECTION_ID"); + String userInput = Configuration.getGlobalConfiguration().get("FABRIC_IQ_USER_INPUT", + "Use FabricIQ to summarize the available enterprise context."); AgentsClientBuilder builder = new AgentsClientBuilder() .credential(new DefaultAzureCredentialBuilder().build()) @@ -52,10 +55,10 @@ public static void main(String[] args) { // END: com.azure.ai.agents.define_fabric_iq PromptAgentDefinition agentDefinition = new PromptAgentDefinition(model) - .setInstructions("You are a data assistant that can use FabricIQ for grounded enterprise answers.") + .setInstructions("Use the available Fabric IQ tools to answer questions and perform tasks.") .setTools(Collections.singletonList(fabricIqTool)); - AgentVersionDetails agent = agentsClient.createAgentVersion("fabric-iq-agent", agentDefinition); + AgentVersionDetails agent = agentsClient.createAgentVersion("MyAgent", agentDefinition); System.out.printf("Agent created: %s (version %s)%n", agent.getName(), agent.getVersion()); try { @@ -65,7 +68,7 @@ public static void main(String[] args) { Response response = responsesClient.createAzureResponse( new AzureCreateResponseOptions().setAgentReference(agentReference), ResponseCreateParams.builder() - .input("Use FabricIQ to summarize the available enterprise context.")); + .input(userInput)); System.out.println("Response: " + response.output()); } finally { diff --git a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/ClientTestBase.java b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/ClientTestBase.java index 04ea90051026..097a3284ba8f 100644 --- 
a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/ClientTestBase.java +++ b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/ClientTestBase.java @@ -42,10 +42,10 @@ protected AgentsClientBuilder getClientBuilder(HttpClient httpClient, AgentsServ builder.endpoint("https://localhost:8080").credential(new MockTokenCredential()); } else if (testMode == TestMode.RECORD) { builder.addPolicy(interceptorManager.getRecordPolicy()) - .endpoint(Configuration.getGlobalConfiguration().get("AZURE_AGENTS_ENDPOINT")) + .endpoint(Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT")) .credential(new DefaultAzureCredentialBuilder().build()); } else { - builder.endpoint(Configuration.getGlobalConfiguration().get("AZURE_AGENTS_ENDPOINT")) + builder.endpoint(Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT")) .credential(new DefaultAzureCredentialBuilder().build()); } diff --git a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/tools/FabricIQSamplesTests.java b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/tools/FabricIQSamplesTests.java index eda638e6feaf..77c421fc1204 100644 --- a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/tools/FabricIQSamplesTests.java +++ b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/tools/FabricIQSamplesTests.java @@ -3,23 +3,42 @@ package com.azure.ai.agents.tools; +import com.azure.ai.agents.AgentsAsyncClient; +import com.azure.ai.agents.AgentsClient; +import com.azure.ai.agents.AgentsClientBuilder; import com.azure.ai.agents.AgentsServiceVersion; import com.azure.ai.agents.ClientTestBase; +import com.azure.ai.agents.ResponsesAsyncClient; +import com.azure.ai.agents.ResponsesClient; import com.azure.core.http.HttpClient; -import org.junit.jupiter.api.Disabled; +import com.azure.core.test.TestMode; +import com.azure.core.util.Configuration; +import com.azure.ai.agents.models.AgentReference; +import com.azure.ai.agents.models.AgentVersionDetails; +import 
com.azure.ai.agents.models.AzureCreateResponseOptions; +import com.azure.ai.agents.models.FabricIQPreviewTool; +import com.azure.ai.agents.models.PromptAgentDefinition; +import com.openai.models.responses.Response; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.ResponseStatus; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.stream.Stream; import static com.azure.core.test.TestProxyTestBase.getHttpClients; public class FabricIQSamplesTests extends ClientTestBase { private static final String DISPLAY_NAME_WITH_ARGUMENTS = "{displayName} with [{arguments}]"; + private static final String DEFAULT_USER_INPUT = "Show weather events in Texas."; static Stream getTestParameters() { List argumentsList = new ArrayList<>(); @@ -27,17 +46,98 @@ static Stream getTestParameters() { return argumentsList.stream(); } - @Disabled("Requires FABRIC_IQ_PROJECT_CONNECTION_ID and FOUNDRY_MODEL_NAME.") + @Timeout(value = 5, unit = TimeUnit.MINUTES) @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) @MethodSource("getTestParameters") public void fabricIqSyncSample(HttpClient httpClient, AgentsServiceVersion serviceVersion) { - Assertions.fail("Enable after providing FABRIC_IQ_PROJECT_CONNECTION_ID and FOUNDRY_MODEL_NAME."); + AgentsClientBuilder builder = getClientBuilder(httpClient, serviceVersion); + AgentsClient agentsClient = builder.buildAgentsClient(); + ResponsesClient responsesClient = builder.buildResponsesClient(); + + String agentName = testResourceNamer.randomName("fabric-iq-sync-", 40); + AgentVersionDetails agent = null; + + try { + agent = 
agentsClient.createAgentVersion(agentName, createAgentDefinition()); + Assertions.assertNotNull(agent); + + AgentReference agentReference = new AgentReference(agent.getName()).setVersion(agent.getVersion()); + Response response = responsesClient.createAzureResponse( + new AzureCreateResponseOptions().setAgentReference(agentReference), + ResponseCreateParams.builder().input(getUserInput())); + + assertCompletedResponse(response); + } finally { + if (agent != null) { + agentsClient.deleteAgentVersion(agent.getName(), agent.getVersion()); + } + } } - @Disabled("Requires FABRIC_IQ_PROJECT_CONNECTION_ID and FOUNDRY_MODEL_NAME.") + @Timeout(value = 5, unit = TimeUnit.MINUTES) @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) @MethodSource("getTestParameters") public void fabricIqAsyncSample(HttpClient httpClient, AgentsServiceVersion serviceVersion) { - Assertions.fail("Enable after providing FABRIC_IQ_PROJECT_CONNECTION_ID and FOUNDRY_MODEL_NAME."); + AgentsClientBuilder builder = getClientBuilder(httpClient, serviceVersion); + AgentsAsyncClient agentsAsyncClient = builder.buildAgentsAsyncClient(); + ResponsesAsyncClient responsesAsyncClient = builder.buildResponsesAsyncClient(); + + String agentName = testResourceNamer.randomName("fabric-iq-async-", 40); + AgentVersionDetails agent = null; + + try { + agent + = agentsAsyncClient.createAgentVersion(agentName, createAgentDefinition()).block(Duration.ofMinutes(2)); + Assertions.assertNotNull(agent); + + AgentReference agentReference = new AgentReference(agent.getName()).setVersion(agent.getVersion()); + Response response + = responsesAsyncClient + .createAzureResponse(new AzureCreateResponseOptions().setAgentReference(agentReference), + ResponseCreateParams.builder().input(getUserInput())) + .block(Duration.ofMinutes(3)); + + assertCompletedResponse(response); + } finally { + if (agent != null) { + agentsAsyncClient.deleteAgentVersion(agent.getName(), agent.getVersion()).block(Duration.ofMinutes(1)); + } + } + } + + 
private PromptAgentDefinition createAgentDefinition() { + FabricIQPreviewTool fabricIqTool + = new FabricIQPreviewTool(getRecordedConfig("FABRIC_IQ_PROJECT_CONNECTION_ID")).setServerLabel("fabric_iq") + .setRequireApproval("never") + .setName("fabric_iq_lookup") + .setDescription("Use FabricIQ to answer questions grounded in enterprise data."); + + return new PromptAgentDefinition(getRecordedConfig("FOUNDRY_MODEL_NAME")) + .setInstructions("Use the available Fabric IQ tools to answer questions and perform tasks.") + .setTools(Collections.singletonList(fabricIqTool)); + } + + private String getRecordedConfig(String name) { + if (getTestMode() == TestMode.PLAYBACK) { + return testResourceNamer.recordValueFromConfig(name); + } + + String value = Configuration.getGlobalConfiguration().get(name); + if (getTestMode() == TestMode.RECORD) { + testResourceNamer.recordValueFromConfig(name); + } + return value; + } + + private static String getUserInput() { + return Configuration.getGlobalConfiguration().get("FABRIC_IQ_USER_INPUT", DEFAULT_USER_INPUT); + } + + private static void assertCompletedResponse(Response response) { + Assertions.assertNotNull(response); + Assertions.assertTrue(response.status().isPresent()); + Assertions.assertEquals(ResponseStatus.COMPLETED, response.status().get()); + Assertions.assertFalse(response.output().isEmpty()); + Assertions.assertTrue(response.output().stream().anyMatch(item -> item.isMessage())); } } diff --git a/sdk/ai/azure-ai-projects/assets.json b/sdk/ai/azure-ai-projects/assets.json index dcc3df11c8e3..4118e9c761eb 100644 --- a/sdk/ai/azure-ai-projects/assets.json +++ b/sdk/ai/azure-ai-projects/assets.json @@ -1 +1 @@ -{"AssetsRepo":"Azure/azure-sdk-assets","AssetsRepoPrefixPath":"java","TagPrefix":"java/ai/azure-ai-projects","Tag": "java/ai/azure-ai-projects_cd9e4bffa9"} \ No newline at end of file +{"AssetsRepo":"Azure/azure-sdk-assets","AssetsRepoPrefixPath":"java","TagPrefix":"java/ai/azure-ai-projects","Tag": 
"java/ai/azure-ai-projects_8ce9469bb9"} \ No newline at end of file diff --git a/sdk/ai/azure-ai-projects/src/samples/java/com/azure/ai/projects/DataGenerationJobWithEvaluationSample.java b/sdk/ai/azure-ai-projects/src/samples/java/com/azure/ai/projects/DataGenerationJobWithEvaluationSample.java new file mode 100644 index 000000000000..4cde5e67823d --- /dev/null +++ b/sdk/ai/azure-ai-projects/src/samples/java/com/azure/ai/projects/DataGenerationJobWithEvaluationSample.java @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.ai.projects; + +import com.azure.ai.projects.models.ApiError; +import com.azure.ai.projects.models.DataGenerationJob; +import com.azure.ai.projects.models.DataGenerationJobInputs; +import com.azure.ai.projects.models.DataGenerationJobOutput; +import com.azure.ai.projects.models.DataGenerationJobOutputOptions; +import com.azure.ai.projects.models.DataGenerationJobScenario; +import com.azure.ai.projects.models.DataGenerationModelOptions; +import com.azure.ai.projects.models.DatasetDataGenerationJobOutput; +import com.azure.ai.projects.models.DatasetVersion; +import com.azure.ai.projects.models.FoundryFeaturesOptInKeys; +import com.azure.ai.projects.models.JobStatus; +import com.azure.ai.projects.models.PromptDataGenerationJobSource; +import com.azure.ai.projects.models.SimpleQnADataGenerationJobOptions; +import com.azure.core.util.Configuration; +import com.azure.identity.DefaultAzureCredentialBuilder; +import com.openai.client.OpenAIClient; +import com.openai.core.JsonField; +import com.openai.core.JsonValue; +import com.openai.models.evals.EvalCreateParams; +import com.openai.models.evals.EvalCreateParams.DataSourceConfig.Custom; +import com.openai.models.evals.EvalCreateParams.DataSourceConfig.Custom.ItemSchema; +import com.openai.models.evals.EvalCreateParams.TestingCriterion; +import com.openai.models.evals.EvalCreateResponse; +import 
com.openai.models.evals.EvalDeleteParams; +import com.openai.models.evals.runs.CreateEvalCompletionsRunDataSource; +import com.openai.models.evals.runs.RunCreateParams; +import com.openai.models.evals.runs.RunCreateResponse; +import com.openai.models.evals.runs.RunRetrieveParams; +import com.openai.models.evals.runs.RunRetrieveResponse; +import com.openai.models.evals.runs.outputitems.OutputItemListParams; +import com.openai.models.evals.runs.outputitems.OutputItemListResponse; +import com.openai.models.responses.EasyInputMessage; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * End-to-end sample combining data generation with an OpenAI evaluation run. + * + *

<p>The sample creates a Simple QnA data generation job from an inline prompt, waits for the generated dataset,
+ * creates an OpenAI evaluation using Azure AI built-in evaluators, runs the evaluation against the dataset, and cleans
+ * up the evaluation and data generation job.</p>
+ *
+ * <p>Before running the sample, set these environment variables:</p>
+ * <ul>
+ * <li>{@code FOUNDRY_PROJECT_ENDPOINT} - the Azure AI Foundry project endpoint.</li>
+ * <li>{@code FOUNDRY_MODEL_NAME} - the model deployment name used for generation and judging.</li>
+ * <li>{@code DATASET_NAME} - optional, the generated dataset name.</li>
+ * <li>{@code POLL_INTERVAL_SECONDS} - optional, seconds to wait between polling attempts.</li>
+ * </ul>
    + */ +public class DataGenerationJobWithEvaluationSample { + private static final FoundryFeaturesOptInKeys DATA_GENERATION_PREVIEW + = FoundryFeaturesOptInKeys.DATA_GENERATION_JOBS_V1_PREVIEW; + private static final String DEFAULT_DATASET_NAME = "dataset-generation-eval-sample"; + private static final int DEFAULT_POLL_INTERVAL_SECONDS = 10; + + public static void main(String[] args) throws InterruptedException { + String endpoint = Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT"); + String modelName = Configuration.getGlobalConfiguration().get("FOUNDRY_MODEL_NAME"); + String datasetName = Configuration.getGlobalConfiguration().get("DATASET_NAME", DEFAULT_DATASET_NAME); + int pollIntervalSeconds = Integer.parseInt(Configuration.getGlobalConfiguration() + .get("POLL_INTERVAL_SECONDS", String.valueOf(DEFAULT_POLL_INTERVAL_SECONDS))); + + AIProjectClientBuilder projectClientBuilder = new AIProjectClientBuilder() + .endpoint(endpoint) + .credential(new DefaultAzureCredentialBuilder().build()); + + DataGenerationJobsClient dataGenerationJobsClient = projectClientBuilder.buildDataGenerationJobsClient(); + DatasetsClient datasetsClient = projectClientBuilder.buildDatasetsClient(); + OpenAIClient openAIClient = projectClientBuilder.buildOpenAIClient(); + + DataGenerationJob job = null; + EvalCreateResponse eval = null; + + try { + System.out.println("Create a data generation job."); + job = dataGenerationJobsClient.createGenerationJob(createDataGenerationJob(modelName, datasetName), + DATA_GENERATION_PREVIEW, UUID.randomUUID().toString()); + System.out.printf("Created data generation job `%s` (status: `%s`).%n", job.getId(), job.getStatus()); + + job = waitForDataGenerationJob(dataGenerationJobsClient, job.getId(), pollIntervalSeconds); + System.out.printf("Final job status: `%s`.%n", job.getStatus()); + + if (!JobStatus.SUCCEEDED.equals(job.getStatus())) { + ApiError error = job.getError(); + String message = error == null ? 
"" : error.getMessage(); + throw new IllegalStateException(String.format("Job `%s` ended with status `%s`: %s", + job.getId(), job.getStatus(), message)); + } + + DatasetDataGenerationJobOutput output = findDatasetOutput(job); + DatasetVersion dataset = datasetsClient.getDatasetVersion(output.getName(), output.getVersion()); + System.out.printf("Generated dataset: name=`%s` version=`%s` id=`%s`%n", + dataset.getName(), dataset.getVersion(), dataset.getId()); + + System.out.println("Create the evaluation."); + eval = openAIClient.evals().create(createEvaluationParams(modelName)); + System.out.printf("Evaluation created (id: %s).%n", eval.id()); + + System.out.printf("Create an evaluation run that consumes dataset `%s`.%n", dataset.getId()); + RunCreateResponse evalRun = openAIClient.evals().runs().create(createEvaluationRunParams(eval.id(), + dataset.getId(), modelName)); + System.out.printf("Evaluation run created (id: %s).%n", evalRun.id()); + + RunRetrieveResponse completedRun = waitForEvaluationRun(openAIClient, eval.id(), evalRun.id(), + pollIntervalSeconds); + System.out.printf("Final eval run status: `%s`.%n", completedRun.status()); + + if ("completed".equals(completedRun.status())) { + System.out.printf("Result counts: %s%n", completedRun.resultCounts()); + System.out.printf("Eval run report URL: %s%n", completedRun.reportUrl()); + printOutputItems(openAIClient, eval.id(), evalRun.id()); + } else { + System.out.println("Evaluation run did not complete successfully."); + } + } finally { + if (eval != null) { + System.out.printf("Delete evaluation `%s`.%n", eval.id()); + openAIClient.evals().delete(EvalDeleteParams.builder().evalId(eval.id()).build()); + } + if (job != null) { + System.out.printf("Delete the data generation job `%s`.%n", job.getId()); + dataGenerationJobsClient.deleteGenerationJob(job.getId(), DATA_GENERATION_PREVIEW); + } + } + } + + static DataGenerationJob createDataGenerationJob(String modelName, String datasetName) { + 
PromptDataGenerationJobSource source = new PromptDataGenerationJobSource( + "Contoso offers a full refund within 30 days of purchase for any product returned in its original " + + "condition. After 30 days, store credit may be issued at the discretion of customer support. " + + "Digital goods are non-refundable once downloaded.") + .setDescription("Contoso refund policy"); + + SimpleQnADataGenerationJobOptions options = new SimpleQnADataGenerationJobOptions(15) + .setModelOptions(new DataGenerationModelOptions(modelName)); + + DataGenerationJobOutputOptions outputOptions = new DataGenerationJobOutputOptions() + .setName(datasetName) + .setDescription("QnA pairs generated from the Contoso refund policy prompt.") + .setTags(Collections.singletonMap("sample", "dataset-generation-with-evaluation")); + + DataGenerationJobInputs inputs = new DataGenerationJobInputs("qna-from-policy-prompt", + Collections.singletonList(source), options, DataGenerationJobScenario.EVALUATION) + .setOutputOptions(outputOptions); + + return new DataGenerationJob().setInputs(inputs); + } + + private static DataGenerationJob waitForDataGenerationJob(DataGenerationJobsClient dataGenerationJobsClient, + String jobId, int pollIntervalSeconds) throws InterruptedException { + System.out.printf("Poll job `%s` until it reaches a terminal state.", jobId); + DataGenerationJob job; + do { + Thread.sleep(pollIntervalSeconds * 1000L); + System.out.print("."); + job = dataGenerationJobsClient.getGenerationJob(jobId, DATA_GENERATION_PREVIEW); + } while (!isTerminalStatus(job.getStatus())); + System.out.println(); + return job; + } + + static boolean isTerminalStatus(JobStatus status) { + return JobStatus.SUCCEEDED.equals(status) + || JobStatus.FAILED.equals(status) + || JobStatus.CANCELLED.equals(status); + } + + static DatasetDataGenerationJobOutput findDatasetOutput(DataGenerationJob job) { + if (job.getResult() != null && job.getResult().getOutputs() != null) { + for (DataGenerationJobOutput output : 
job.getResult().getOutputs()) { + if (output instanceof DatasetDataGenerationJobOutput) { + DatasetDataGenerationJobOutput datasetOutput = (DatasetDataGenerationJobOutput) output; + if (datasetOutput.getName() != null && datasetOutput.getVersion() != null) { + return datasetOutput; + } + } + } + } + + throw new IllegalStateException(String.format("Job `%s` did not produce a dataset output.", job.getId())); + } + + static EvalCreateParams createEvaluationParams(String modelName) { + Map queryProperty = new LinkedHashMap<>(); + queryProperty.put("type", "string"); + + Map groundTruthProperty = new LinkedHashMap<>(); + groundTruthProperty.put("type", "string"); + + Map properties = new LinkedHashMap<>(); + properties.put("query", queryProperty); + properties.put("ground_truth", groundTruthProperty); + + ItemSchema itemSchema = ItemSchema.builder() + .putAdditionalProperty("type", JsonValue.from("object")) + .putAdditionalProperty("properties", JsonValue.from(properties)) + .putAdditionalProperty("required", JsonValue.from(Collections.singletonList("query"))) + .build(); + + Custom dataSourceConfig = Custom.builder() + .itemSchema(itemSchema) + .includeSampleSchema(true) + .build(); + + return EvalCreateParams.builder() + .name("generated-qna-evaluation") + .dataSourceConfig(dataSourceConfig) + .testingCriteria(createAzureAIEvaluatorCriteria(modelName)) + .build(); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + private static JsonField> createAzureAIEvaluatorCriteria(String modelName) { + // openai-java 4.14.0 does not have a typed Azure AI evaluator testing criterion model. 
+ return (JsonField) JsonValue.from(Arrays.asList( + createAzureAIEvaluator("coherence", "builtin.coherence", modelName, + mapOf("query", "{{item.query}}", "response", "{{sample.output_text}}")), + createAzureAIEvaluator("fluency", "builtin.fluency", modelName, + Collections.singletonMap("response", "{{sample.output_text}}")))); + } + + private static Map createAzureAIEvaluator(String name, String evaluatorName, String modelName, + Map dataMapping) { + Map initializationParameters = new LinkedHashMap<>(); + initializationParameters.put("deployment_name", modelName); + + Map evaluator = new LinkedHashMap<>(); + evaluator.put("type", "azure_ai_evaluator"); + evaluator.put("name", name); + evaluator.put("evaluator_name", evaluatorName); + evaluator.put("initialization_parameters", initializationParameters); + evaluator.put("data_mapping", dataMapping); + return evaluator; + } + + static RunCreateParams createEvaluationRunParams(String evalId, String datasetId, String modelName) { + CreateEvalCompletionsRunDataSource.InputMessages.Template inputMessages + = CreateEvalCompletionsRunDataSource.InputMessages.Template.builder() + .addTemplate(EasyInputMessage.builder() + .role(EasyInputMessage.Role.DEVELOPER) + .content("You are a Contoso customer-support assistant. 
Answer the user's question about the " + + "Contoso refund policy clearly and concisely.") + .build()) + .addTemplate(EasyInputMessage.builder() + .role(EasyInputMessage.Role.USER) + .content("{{item.query}}") + .build()) + .build(); + + CreateEvalCompletionsRunDataSource dataSource = CreateEvalCompletionsRunDataSource.builder() + .fileIdSource(datasetId) + .type(CreateEvalCompletionsRunDataSource.Type.COMPLETIONS) + .inputMessages(inputMessages) + .model(modelName) + .build(); + + return RunCreateParams.builder() + .evalId(evalId) + .name("generated-qna-evaluation-run") + .dataSource(dataSource) + .build(); + } + + private static RunRetrieveResponse waitForEvaluationRun(OpenAIClient openAIClient, String evalId, String runId, + int pollIntervalSeconds) throws InterruptedException { + RunRetrieveResponse evalRun = openAIClient.evals().runs().retrieve(RunRetrieveParams.builder() + .evalId(evalId) + .runId(runId) + .build()); + while (!"completed".equals(evalRun.status()) && !"failed".equals(evalRun.status())) { + Thread.sleep(pollIntervalSeconds * 1000L); + evalRun = openAIClient.evals().runs().retrieve(RunRetrieveParams.builder() + .evalId(evalId) + .runId(runId) + .build()); + } + return evalRun; + } + + private static void printOutputItems(OpenAIClient openAIClient, String evalId, String runId) { + int count = 0; + for (OutputItemListResponse item : openAIClient.evals().runs().outputItems().list(OutputItemListParams.builder() + .evalId(evalId) + .runId(runId) + .build()).autoPager()) { + count++; + System.out.printf(" item %d: status=%s | %s%n", count, item.status(), item.results()); + } + System.out.printf("Output items (total: %d).%n", count); + } + + private static Map mapOf(String firstKey, String firstValue, String secondKey, String secondValue) { + Map map = new LinkedHashMap<>(); + map.put(firstKey, firstValue); + map.put(secondKey, secondValue); + return map; + } +} diff --git 
a/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/ClientTestBase.java b/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/ClientTestBase.java index e8f526ddc5a6..551b40ea0028 100644 --- a/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/ClientTestBase.java +++ b/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/ClientTestBase.java @@ -55,10 +55,10 @@ protected AIProjectClientBuilder getClientBuilder(HttpClient httpClient, builder.endpoint("https://localhost:8080").credential(new MockTokenCredential()); } else if (testMode == TestMode.RECORD) { builder.addPolicy(interceptorManager.getRecordPolicy()) - .endpoint(Configuration.getGlobalConfiguration().get("AI_PROJECTS_ENDPOINT")) + .endpoint(Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT")) .credential(new DefaultAzureCredentialBuilder().build()); } else { - builder.endpoint(Configuration.getGlobalConfiguration().get("AI_PROJECTS_ENDPOINT")) + builder.endpoint(Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT")) .credential(new DefaultAzureCredentialBuilder().build()); } diff --git a/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/SamplesTests.java b/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/SamplesTests.java index 25da419c09db..0a6c8f2ed104 100644 --- a/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/SamplesTests.java +++ b/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/SamplesTests.java @@ -3,16 +3,31 @@ package com.azure.ai.projects; import com.azure.ai.agents.models.PageOrder; +import com.azure.ai.projects.models.ApiError; import com.azure.ai.projects.models.DataGenerationJob; +import com.azure.ai.projects.models.DatasetDataGenerationJobOutput; +import com.azure.ai.projects.models.DatasetVersion; import com.azure.ai.projects.models.FoundryFeaturesOptInKeys; +import com.azure.ai.projects.models.JobStatus; import com.azure.ai.projects.models.ModelVersion; import 
com.azure.ai.projects.models.SkillDetails;
 import com.azure.core.exception.ResourceNotFoundException;
 import com.azure.core.http.HttpClient;
+import com.azure.core.test.TestMode;
 import com.azure.core.test.annotation.LiveOnly;
 import com.azure.core.util.BinaryData;
+import com.azure.core.util.Configuration;
+import com.openai.client.OpenAIClient;
+import com.openai.models.evals.EvalCreateResponse;
+import com.openai.models.evals.EvalDeleteParams;
+import com.openai.models.evals.runs.RunCreateResponse;
+import com.openai.models.evals.runs.RunRetrieveParams;
+import com.openai.models.evals.runs.RunRetrieveResponse;
+import com.openai.models.evals.runs.outputitems.OutputItemListParams;
+import com.openai.models.evals.runs.outputitems.OutputItemListResponse;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Timeout;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.MethodSource;
 import reactor.test.StepVerifier;
@@ -22,6 +37,7 @@
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.zip.ZipEntry;
 import java.util.zip.ZipOutputStream;
@@ -184,6 +200,86 @@ public void dataGenerationCreateGetCancelDeleteSample(HttpClient httpClient,
             "Enable after providing FOUNDRY_MODEL_NAME and deciding whether to record this long-running preview flow.");
     }
 
+    /**
+     * Runs the data-generation-plus-evaluation sample flow end to end: creates a data generation job, waits for
+     * it to finish, reads the generated dataset version, creates an evaluation and an evaluation run over that
+     * dataset, waits for the run, and checks that the run produced at least one output item.
+     *
+     * <p>NOTE(review): each polling phase allows up to 180 polls at a 5-second interval (15 minutes when the
+     * sleeps are real), so the two phases together can exceed the 20-minute timeout below.</p>
+     */
+    @Timeout(value = 20, unit = TimeUnit.MINUTES)
+    @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
+    @MethodSource("com.azure.ai.projects.TestUtils#getTestParameters")
+    public void dataGenerationJobWithEvaluationSample(HttpClient httpClient, AIProjectsServiceVersion serviceVersion)
+        throws InterruptedException {
+        AIProjectClientBuilder projectClientBuilder = getClientBuilder(httpClient, serviceVersion);
+        DataGenerationJobsClient dataGenerationJobsClient = projectClientBuilder.buildDataGenerationJobsClient();
+        DatasetsClient datasetsClient = projectClientBuilder.buildDatasetsClient();
+        OpenAIClient openAIClient = projectClientBuilder.buildOpenAIClient();
+
+        String modelName = getRecordedConfig("FOUNDRY_MODEL_NAME");
+        String datasetName = testResourceNamer.randomName("dataset-generation-eval-", 64);
+
+        DataGenerationJob job = null;
+        EvalCreateResponse eval = null;
+
+        try {
+            job = dataGenerationJobsClient.createGenerationJob(
+                DataGenerationJobWithEvaluationSample.createDataGenerationJob(modelName, datasetName),
+                DATA_GENERATION_PREVIEW, testResourceNamer.randomUuid());
+
+            job = waitForDataGenerationJob(dataGenerationJobsClient, job.getId(), 5, 180);
+            if (!JobStatus.SUCCEEDED.equals(job.getStatus())) {
+                ApiError error = job.getError();
+                String message = error == null ? "" : error.getMessage();
+                Assertions
+                    .fail(String.format("Job `%s` ended with status `%s`: %s", job.getId(), job.getStatus(), message));
+            }
+
+            DatasetDataGenerationJobOutput output = DataGenerationJobWithEvaluationSample.findDatasetOutput(job);
+            // findDatasetOutput is defined in the sample class; guard in case it yields null (TODO confirm).
+            Assertions.assertNotNull(output, "Data generation job returned no dataset output.");
+            DatasetVersion dataset = datasetsClient.getDatasetVersion(output.getName(), output.getVersion());
+            Assertions.assertNotNull(dataset);
+            Assertions.assertNotNull(dataset.getId());
+
+            eval = openAIClient.evals().create(DataGenerationJobWithEvaluationSample.createEvaluationParams(modelName));
+            Assertions.assertNotNull(eval);
+
+            RunCreateResponse evalRun = openAIClient.evals()
+                .runs()
+                .create(DataGenerationJobWithEvaluationSample.createEvaluationRunParams(eval.id(), dataset.getId(),
+                    modelName));
+            Assertions.assertNotNull(evalRun);
+
+            RunRetrieveResponse completedRun = waitForEvaluationRun(openAIClient, eval.id(), evalRun.id(), 5, 180);
+            Assertions.assertEquals("completed", completedRun.status());
+            Assertions.assertNotNull(completedRun.resultCounts());
+
+            int outputItemCount = 0;
+            for (OutputItemListResponse ignored : openAIClient.evals()
+                .runs()
+                .outputItems()
+                .list(OutputItemListParams.builder().evalId(eval.id()).runId(evalRun.id()).build())
+                .autoPager()) {
+                outputItemCount++;
+            }
+            Assertions.assertTrue(outputItemCount > 0);
+        } finally {
+            // Clean up independently so a failure deleting the evaluation does not leak the generation job.
+            try {
+                if (eval != null) {
+                    openAIClient.evals().delete(EvalDeleteParams.builder().evalId(eval.id()).build());
+                }
+            } finally {
+                if (job != null) {
+                    dataGenerationJobsClient.deleteGenerationJob(job.getId(), DATA_GENERATION_PREVIEW);
+                }
+            }
+        }
+    }
+
     @Disabled("Requires FOUNDRY_MODEL_ASSET_NAME, FOUNDRY_MODEL_ASSET_VERSION, and FOUNDRY_MODEL_BLOB_URI.")
     @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
     @MethodSource("com.azure.ai.projects.TestUtils#getTestParameters")
@@ -206,4 +302,73 @@ private static BinaryData createSkillPackage() throws IOException {
 
         return BinaryData.fromBytes(outputStream.toByteArray());
     }
+
+    /**
+     * Polls a data generation job until the sample reports its status as terminal or the attempt budget is used up.
+     *
+     * @param dataGenerationJobsClient client used to fetch the job.
+     * @param jobId identifier of the job to poll.
+     * @param pollIntervalSeconds delay before each poll, in seconds.
+     * @param maxAttempts maximum number of polls.
+     * @return the last job state fetched; it may be non-terminal if {@code maxAttempts} was reached.
+     * @throws InterruptedException declared for callers; NOTE(review): presumably raised by the sleep helper.
+     */
+    private DataGenerationJob waitForDataGenerationJob(DataGenerationJobsClient dataGenerationJobsClient, String jobId,
+        int pollIntervalSeconds, int maxAttempts) throws InterruptedException {
+        DataGenerationJob job;
+        int attempts = 0;
+        do {
+            sleepIfRunningAgainstService(pollIntervalSeconds * 1000L);
+            job = dataGenerationJobsClient.getGenerationJob(jobId, DATA_GENERATION_PREVIEW);
+            attempts++;
+        } while (!DataGenerationJobWithEvaluationSample.isTerminalStatus(job.getStatus()) && attempts < maxAttempts);
+        return job;
+    }
+
+    /**
+     * Polls an evaluation run until it reaches a terminal status or the attempt budget is used up.
+     *
+     * @param openAIClient client used to retrieve the run.
+     * @param evalId identifier of the evaluation that owns the run.
+     * @param runId identifier of the run to poll.
+     * @param pollIntervalSeconds delay before each re-poll, in seconds.
+     * @param maxAttempts maximum number of re-polls after the initial retrieve.
+     * @return the last run state retrieved; it may be non-terminal if {@code maxAttempts} was reached.
+     * @throws InterruptedException declared for callers; NOTE(review): presumably raised by the sleep helper.
+     */
+    private RunRetrieveResponse waitForEvaluationRun(OpenAIClient openAIClient, String evalId, String runId,
+        int pollIntervalSeconds, int maxAttempts) throws InterruptedException {
+        RunRetrieveParams retrieveParams = RunRetrieveParams.builder().evalId(evalId).runId(runId).build();
+        RunRetrieveResponse evalRun = openAIClient.evals().runs().retrieve(retrieveParams);
+        int attempts = 0;
+        while (!isTerminalEvaluationRunStatus(evalRun.status()) && attempts < maxAttempts) {
+            sleepIfRunningAgainstService(pollIntervalSeconds * 1000L);
+            evalRun = openAIClient.evals().runs().retrieve(retrieveParams);
+            attempts++;
+        }
+        return evalRun;
+    }
+
+    /**
+     * Reports whether an evaluation run status means the run will not progress further.
+     *
+     * @param status status string reported by the service.
+     * @return {@code true} for {@code completed}, {@code failed}, or {@code canceled}.
+     */
+    private static boolean isTerminalEvaluationRunStatus(String status) {
+        // "canceled" is terminal too; without it a canceled run is polled until the attempt budget runs out.
+        return "completed".equals(status) || "failed".equals(status) || "canceled".equals(status);
+    }
+
+    // Resolves setting `name`: PLAYBACK returns what testResourceNamer replays, otherwise the configured value.
+    private String getRecordedConfig(String name) {
+        if (getTestMode() == TestMode.PLAYBACK) {
+            return testResourceNamer.recordValueFromConfig(name);
+        }
+        String value = Configuration.getGlobalConfiguration().get(name);
+        if (getTestMode() == TestMode.RECORD) {
+            testResourceNamer.recordValueFromConfig(value); // store the value itself, not the setting name
+        }
+        return value;
+    }
 }