diff --git a/sdk/ai/azure-ai-agents/assets.json b/sdk/ai/azure-ai-agents/assets.json
index bf21adf5156b..bbeb80ce0ad6 100644
--- a/sdk/ai/azure-ai-agents/assets.json
+++ b/sdk/ai/azure-ai-agents/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "java",
"TagPrefix": "java/ai/azure-ai-agents",
- "Tag": "java/ai/azure-ai-agents_5238437e7d"
+ "Tag": "java/ai/azure-ai-agents_767bd42da6"
}
diff --git a/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQAsync.java b/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQAsync.java
index 5b4c67bc5cb2..e22571576b8f 100644
--- a/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQAsync.java
+++ b/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQAsync.java
@@ -28,6 +28,7 @@
*
FOUNDRY_PROJECT_ENDPOINT - The Azure AI Project endpoint.
* FOUNDRY_MODEL_NAME - The model deployment name.
* FABRIC_IQ_PROJECT_CONNECTION_ID - The FabricIQ connection ID.
+ * FABRIC_IQ_USER_INPUT - Optional. The natural-language question to send to the agent.
*
*/
public class FabricIQAsync {
@@ -35,6 +36,8 @@ public static void main(String[] args) {
String endpoint = Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT");
String model = Configuration.getGlobalConfiguration().get("FOUNDRY_MODEL_NAME");
String fabricIqConnectionId = Configuration.getGlobalConfiguration().get("FABRIC_IQ_PROJECT_CONNECTION_ID");
+ String userInput = Configuration.getGlobalConfiguration().get("FABRIC_IQ_USER_INPUT",
+ "Use FabricIQ to summarize the available enterprise context.");
AgentsClientBuilder builder = new AgentsClientBuilder()
.credential(new DefaultAzureCredentialBuilder().build())
@@ -51,10 +54,10 @@ public static void main(String[] args) {
.setDescription("Use FabricIQ to answer questions grounded in enterprise data.");
PromptAgentDefinition agentDefinition = new PromptAgentDefinition(model)
- .setInstructions("You are a data assistant that can use FabricIQ for grounded enterprise answers.")
+ .setInstructions("Use the available Fabric IQ tools to answer questions and perform tasks.")
.setTools(Collections.singletonList(fabricIqTool));
- agentsAsyncClient.createAgentVersion("fabric-iq-async-agent", agentDefinition)
+ agentsAsyncClient.createAgentVersion("MyAgent", agentDefinition)
.flatMap(agent -> {
agentRef.set(agent);
System.out.printf("Agent created: %s (version %s)%n", agent.getName(), agent.getVersion());
@@ -65,7 +68,7 @@ public static void main(String[] args) {
return responsesAsyncClient.createAzureResponse(
new AzureCreateResponseOptions().setAgentReference(agentReference),
ResponseCreateParams.builder()
- .input("Use FabricIQ to summarize the available enterprise context."));
+ .input(userInput));
})
.doOnNext(response -> System.out.println("Response: " + response.output()))
.then(Mono.defer(() -> {
diff --git a/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQSync.java b/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQSync.java
index 34d7272b8c06..e0b397f4fab9 100644
--- a/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQSync.java
+++ b/sdk/ai/azure-ai-agents/src/samples/java/com/azure/ai/agents/tools/FabricIQSync.java
@@ -26,6 +26,7 @@
* FOUNDRY_PROJECT_ENDPOINT - The Azure AI Project endpoint.
* FOUNDRY_MODEL_NAME - The model deployment name.
* FABRIC_IQ_PROJECT_CONNECTION_ID - The FabricIQ connection ID.
+ * FABRIC_IQ_USER_INPUT - Optional. The natural-language question to send to the agent.
*
*/
public class FabricIQSync {
@@ -33,6 +34,8 @@ public static void main(String[] args) {
String endpoint = Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT");
String model = Configuration.getGlobalConfiguration().get("FOUNDRY_MODEL_NAME");
String fabricIqConnectionId = Configuration.getGlobalConfiguration().get("FABRIC_IQ_PROJECT_CONNECTION_ID");
+ String userInput = Configuration.getGlobalConfiguration().get("FABRIC_IQ_USER_INPUT",
+ "Use FabricIQ to summarize the available enterprise context.");
AgentsClientBuilder builder = new AgentsClientBuilder()
.credential(new DefaultAzureCredentialBuilder().build())
@@ -52,10 +55,10 @@ public static void main(String[] args) {
// END: com.azure.ai.agents.define_fabric_iq
PromptAgentDefinition agentDefinition = new PromptAgentDefinition(model)
- .setInstructions("You are a data assistant that can use FabricIQ for grounded enterprise answers.")
+ .setInstructions("Use the available Fabric IQ tools to answer questions and perform tasks.")
.setTools(Collections.singletonList(fabricIqTool));
- AgentVersionDetails agent = agentsClient.createAgentVersion("fabric-iq-agent", agentDefinition);
+ AgentVersionDetails agent = agentsClient.createAgentVersion("MyAgent", agentDefinition);
System.out.printf("Agent created: %s (version %s)%n", agent.getName(), agent.getVersion());
try {
@@ -65,7 +68,7 @@ public static void main(String[] args) {
Response response = responsesClient.createAzureResponse(
new AzureCreateResponseOptions().setAgentReference(agentReference),
ResponseCreateParams.builder()
- .input("Use FabricIQ to summarize the available enterprise context."));
+ .input(userInput));
System.out.println("Response: " + response.output());
} finally {
diff --git a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/ClientTestBase.java b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/ClientTestBase.java
index 04ea90051026..097a3284ba8f 100644
--- a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/ClientTestBase.java
+++ b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/ClientTestBase.java
@@ -42,10 +42,10 @@ protected AgentsClientBuilder getClientBuilder(HttpClient httpClient, AgentsServ
builder.endpoint("https://localhost:8080").credential(new MockTokenCredential());
} else if (testMode == TestMode.RECORD) {
builder.addPolicy(interceptorManager.getRecordPolicy())
- .endpoint(Configuration.getGlobalConfiguration().get("AZURE_AGENTS_ENDPOINT"))
+ .endpoint(Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT"))
.credential(new DefaultAzureCredentialBuilder().build());
} else {
- builder.endpoint(Configuration.getGlobalConfiguration().get("AZURE_AGENTS_ENDPOINT"))
+ builder.endpoint(Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT"))
.credential(new DefaultAzureCredentialBuilder().build());
}
diff --git a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/tools/FabricIQSamplesTests.java b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/tools/FabricIQSamplesTests.java
index eda638e6feaf..77c421fc1204 100644
--- a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/tools/FabricIQSamplesTests.java
+++ b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/tools/FabricIQSamplesTests.java
@@ -3,23 +3,42 @@
package com.azure.ai.agents.tools;
+import com.azure.ai.agents.AgentsAsyncClient;
+import com.azure.ai.agents.AgentsClient;
+import com.azure.ai.agents.AgentsClientBuilder;
import com.azure.ai.agents.AgentsServiceVersion;
import com.azure.ai.agents.ClientTestBase;
+import com.azure.ai.agents.ResponsesAsyncClient;
+import com.azure.ai.agents.ResponsesClient;
import com.azure.core.http.HttpClient;
-import org.junit.jupiter.api.Disabled;
+import com.azure.core.test.TestMode;
+import com.azure.core.util.Configuration;
+import com.azure.ai.agents.models.AgentReference;
+import com.azure.ai.agents.models.AgentVersionDetails;
+import com.azure.ai.agents.models.AzureCreateResponseOptions;
+import com.azure.ai.agents.models.FabricIQPreviewTool;
+import com.azure.ai.agents.models.PromptAgentDefinition;
+import com.openai.models.responses.Response;
+import com.openai.models.responses.ResponseCreateParams;
+import com.openai.models.responses.ResponseStatus;
import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
+import java.time.Duration;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
+import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import static com.azure.core.test.TestProxyTestBase.getHttpClients;
public class FabricIQSamplesTests extends ClientTestBase {
private static final String DISPLAY_NAME_WITH_ARGUMENTS = "{displayName} with [{arguments}]";
+ private static final String DEFAULT_USER_INPUT = "Show weather events in Texas.";
static Stream getTestParameters() {
List argumentsList = new ArrayList<>();
@@ -27,17 +46,98 @@ static Stream getTestParameters() {
return argumentsList.stream();
}
- @Disabled("Requires FABRIC_IQ_PROJECT_CONNECTION_ID and FOUNDRY_MODEL_NAME.")
+ @Timeout(value = 5, unit = TimeUnit.MINUTES)
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("getTestParameters")
public void fabricIqSyncSample(HttpClient httpClient, AgentsServiceVersion serviceVersion) {
- Assertions.fail("Enable after providing FABRIC_IQ_PROJECT_CONNECTION_ID and FOUNDRY_MODEL_NAME.");
+ AgentsClientBuilder builder = getClientBuilder(httpClient, serviceVersion);
+ AgentsClient agentsClient = builder.buildAgentsClient();
+ ResponsesClient responsesClient = builder.buildResponsesClient();
+
+ String agentName = testResourceNamer.randomName("fabric-iq-sync-", 40);
+ AgentVersionDetails agent = null;
+
+ try {
+ agent = agentsClient.createAgentVersion(agentName, createAgentDefinition());
+ Assertions.assertNotNull(agent);
+
+ AgentReference agentReference = new AgentReference(agent.getName()).setVersion(agent.getVersion());
+ Response response = responsesClient.createAzureResponse(
+ new AzureCreateResponseOptions().setAgentReference(agentReference),
+ ResponseCreateParams.builder().input(getUserInput()));
+
+ assertCompletedResponse(response);
+ } finally {
+ if (agent != null) {
+ agentsClient.deleteAgentVersion(agent.getName(), agent.getVersion());
+ }
+ }
}
- @Disabled("Requires FABRIC_IQ_PROJECT_CONNECTION_ID and FOUNDRY_MODEL_NAME.")
+ @Timeout(value = 5, unit = TimeUnit.MINUTES)
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("getTestParameters")
public void fabricIqAsyncSample(HttpClient httpClient, AgentsServiceVersion serviceVersion) {
- Assertions.fail("Enable after providing FABRIC_IQ_PROJECT_CONNECTION_ID and FOUNDRY_MODEL_NAME.");
+ AgentsClientBuilder builder = getClientBuilder(httpClient, serviceVersion);
+ AgentsAsyncClient agentsAsyncClient = builder.buildAgentsAsyncClient();
+ ResponsesAsyncClient responsesAsyncClient = builder.buildResponsesAsyncClient();
+
+ String agentName = testResourceNamer.randomName("fabric-iq-async-", 40);
+ AgentVersionDetails agent = null;
+
+ try {
+ agent
+ = agentsAsyncClient.createAgentVersion(agentName, createAgentDefinition()).block(Duration.ofMinutes(2));
+ Assertions.assertNotNull(agent);
+
+ AgentReference agentReference = new AgentReference(agent.getName()).setVersion(agent.getVersion());
+ Response response
+ = responsesAsyncClient
+ .createAzureResponse(new AzureCreateResponseOptions().setAgentReference(agentReference),
+ ResponseCreateParams.builder().input(getUserInput()))
+ .block(Duration.ofMinutes(3));
+
+ assertCompletedResponse(response);
+ } finally {
+ if (agent != null) {
+ agentsAsyncClient.deleteAgentVersion(agent.getName(), agent.getVersion()).block(Duration.ofMinutes(1));
+ }
+ }
+ }
+
+ private PromptAgentDefinition createAgentDefinition() {
+ FabricIQPreviewTool fabricIqTool
+ = new FabricIQPreviewTool(getRecordedConfig("FABRIC_IQ_PROJECT_CONNECTION_ID")).setServerLabel("fabric_iq")
+ .setRequireApproval("never")
+ .setName("fabric_iq_lookup")
+ .setDescription("Use FabricIQ to answer questions grounded in enterprise data.");
+
+ return new PromptAgentDefinition(getRecordedConfig("FOUNDRY_MODEL_NAME"))
+ .setInstructions("Use the available Fabric IQ tools to answer questions and perform tasks.")
+ .setTools(Collections.singletonList(fabricIqTool));
+ }
+
+ private String getRecordedConfig(String name) {
+ if (getTestMode() == TestMode.PLAYBACK) {
+ return testResourceNamer.recordValueFromConfig(name);
+ }
+
+ String value = Configuration.getGlobalConfiguration().get(name);
+ if (getTestMode() == TestMode.RECORD) {
+ testResourceNamer.recordValueFromConfig(name);
+ }
+ return value;
+ }
+
+ private static String getUserInput() {
+ return Configuration.getGlobalConfiguration().get("FABRIC_IQ_USER_INPUT", DEFAULT_USER_INPUT);
+ }
+
+ private static void assertCompletedResponse(Response response) {
+ Assertions.assertNotNull(response);
+ Assertions.assertTrue(response.status().isPresent());
+ Assertions.assertEquals(ResponseStatus.COMPLETED, response.status().get());
+ Assertions.assertFalse(response.output().isEmpty());
+ Assertions.assertTrue(response.output().stream().anyMatch(item -> item.isMessage()));
}
}
diff --git a/sdk/ai/azure-ai-projects/assets.json b/sdk/ai/azure-ai-projects/assets.json
index dcc3df11c8e3..4118e9c761eb 100644
--- a/sdk/ai/azure-ai-projects/assets.json
+++ b/sdk/ai/azure-ai-projects/assets.json
@@ -1 +1 @@
-{"AssetsRepo":"Azure/azure-sdk-assets","AssetsRepoPrefixPath":"java","TagPrefix":"java/ai/azure-ai-projects","Tag": "java/ai/azure-ai-projects_cd9e4bffa9"}
\ No newline at end of file
+{"AssetsRepo":"Azure/azure-sdk-assets","AssetsRepoPrefixPath":"java","TagPrefix":"java/ai/azure-ai-projects","Tag": "java/ai/azure-ai-projects_8ce9469bb9"}
\ No newline at end of file
diff --git a/sdk/ai/azure-ai-projects/src/samples/java/com/azure/ai/projects/DataGenerationJobWithEvaluationSample.java b/sdk/ai/azure-ai-projects/src/samples/java/com/azure/ai/projects/DataGenerationJobWithEvaluationSample.java
new file mode 100644
index 000000000000..4cde5e67823d
--- /dev/null
+++ b/sdk/ai/azure-ai-projects/src/samples/java/com/azure/ai/projects/DataGenerationJobWithEvaluationSample.java
@@ -0,0 +1,308 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.ai.projects;
+
+import com.azure.ai.projects.models.ApiError;
+import com.azure.ai.projects.models.DataGenerationJob;
+import com.azure.ai.projects.models.DataGenerationJobInputs;
+import com.azure.ai.projects.models.DataGenerationJobOutput;
+import com.azure.ai.projects.models.DataGenerationJobOutputOptions;
+import com.azure.ai.projects.models.DataGenerationJobScenario;
+import com.azure.ai.projects.models.DataGenerationModelOptions;
+import com.azure.ai.projects.models.DatasetDataGenerationJobOutput;
+import com.azure.ai.projects.models.DatasetVersion;
+import com.azure.ai.projects.models.FoundryFeaturesOptInKeys;
+import com.azure.ai.projects.models.JobStatus;
+import com.azure.ai.projects.models.PromptDataGenerationJobSource;
+import com.azure.ai.projects.models.SimpleQnADataGenerationJobOptions;
+import com.azure.core.util.Configuration;
+import com.azure.identity.DefaultAzureCredentialBuilder;
+import com.openai.client.OpenAIClient;
+import com.openai.core.JsonField;
+import com.openai.core.JsonValue;
+import com.openai.models.evals.EvalCreateParams;
+import com.openai.models.evals.EvalCreateParams.DataSourceConfig.Custom;
+import com.openai.models.evals.EvalCreateParams.DataSourceConfig.Custom.ItemSchema;
+import com.openai.models.evals.EvalCreateParams.TestingCriterion;
+import com.openai.models.evals.EvalCreateResponse;
+import com.openai.models.evals.EvalDeleteParams;
+import com.openai.models.evals.runs.CreateEvalCompletionsRunDataSource;
+import com.openai.models.evals.runs.RunCreateParams;
+import com.openai.models.evals.runs.RunCreateResponse;
+import com.openai.models.evals.runs.RunRetrieveParams;
+import com.openai.models.evals.runs.RunRetrieveResponse;
+import com.openai.models.evals.runs.outputitems.OutputItemListParams;
+import com.openai.models.evals.runs.outputitems.OutputItemListResponse;
+import com.openai.models.responses.EasyInputMessage;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * End-to-end sample combining data generation with an OpenAI evaluation run.
+ *
+ * The sample creates a Simple QnA data generation job from an inline prompt, waits for the generated dataset,
+ * creates an OpenAI evaluation using Azure AI built-in evaluators, runs the evaluation against the dataset, and cleans
+ * up the evaluation and data generation job.
+ *
+ * Before running the sample, set these environment variables:
+ *
+ * - {@code FOUNDRY_PROJECT_ENDPOINT} - the Azure AI Foundry project endpoint.
+ * - {@code FOUNDRY_MODEL_NAME} - the model deployment name used for generation and judging.
+ * - {@code DATASET_NAME} - optional, the generated dataset name.
+ * - {@code POLL_INTERVAL_SECONDS} - optional, seconds to wait between polling attempts.
+ *
+ */
+public class DataGenerationJobWithEvaluationSample {
+ private static final FoundryFeaturesOptInKeys DATA_GENERATION_PREVIEW
+ = FoundryFeaturesOptInKeys.DATA_GENERATION_JOBS_V1_PREVIEW;
+ private static final String DEFAULT_DATASET_NAME = "dataset-generation-eval-sample";
+ private static final int DEFAULT_POLL_INTERVAL_SECONDS = 10;
+
+ public static void main(String[] args) throws InterruptedException {
+ String endpoint = Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT");
+ String modelName = Configuration.getGlobalConfiguration().get("FOUNDRY_MODEL_NAME");
+ String datasetName = Configuration.getGlobalConfiguration().get("DATASET_NAME", DEFAULT_DATASET_NAME);
+ int pollIntervalSeconds = Integer.parseInt(Configuration.getGlobalConfiguration()
+ .get("POLL_INTERVAL_SECONDS", String.valueOf(DEFAULT_POLL_INTERVAL_SECONDS)));
+
+ AIProjectClientBuilder projectClientBuilder = new AIProjectClientBuilder()
+ .endpoint(endpoint)
+ .credential(new DefaultAzureCredentialBuilder().build());
+
+ DataGenerationJobsClient dataGenerationJobsClient = projectClientBuilder.buildDataGenerationJobsClient();
+ DatasetsClient datasetsClient = projectClientBuilder.buildDatasetsClient();
+ OpenAIClient openAIClient = projectClientBuilder.buildOpenAIClient();
+
+ DataGenerationJob job = null;
+ EvalCreateResponse eval = null;
+
+ try {
+ System.out.println("Create a data generation job.");
+ job = dataGenerationJobsClient.createGenerationJob(createDataGenerationJob(modelName, datasetName),
+ DATA_GENERATION_PREVIEW, UUID.randomUUID().toString());
+ System.out.printf("Created data generation job `%s` (status: `%s`).%n", job.getId(), job.getStatus());
+
+ job = waitForDataGenerationJob(dataGenerationJobsClient, job.getId(), pollIntervalSeconds);
+ System.out.printf("Final job status: `%s`.%n", job.getStatus());
+
+ if (!JobStatus.SUCCEEDED.equals(job.getStatus())) {
+ ApiError error = job.getError();
+ String message = error == null ? "" : error.getMessage();
+ throw new IllegalStateException(String.format("Job `%s` ended with status `%s`: %s",
+ job.getId(), job.getStatus(), message));
+ }
+
+ DatasetDataGenerationJobOutput output = findDatasetOutput(job);
+ DatasetVersion dataset = datasetsClient.getDatasetVersion(output.getName(), output.getVersion());
+ System.out.printf("Generated dataset: name=`%s` version=`%s` id=`%s`%n",
+ dataset.getName(), dataset.getVersion(), dataset.getId());
+
+ System.out.println("Create the evaluation.");
+ eval = openAIClient.evals().create(createEvaluationParams(modelName));
+ System.out.printf("Evaluation created (id: %s).%n", eval.id());
+
+ System.out.printf("Create an evaluation run that consumes dataset `%s`.%n", dataset.getId());
+ RunCreateResponse evalRun = openAIClient.evals().runs().create(createEvaluationRunParams(eval.id(),
+ dataset.getId(), modelName));
+ System.out.printf("Evaluation run created (id: %s).%n", evalRun.id());
+
+ RunRetrieveResponse completedRun = waitForEvaluationRun(openAIClient, eval.id(), evalRun.id(),
+ pollIntervalSeconds);
+ System.out.printf("Final eval run status: `%s`.%n", completedRun.status());
+
+ if ("completed".equals(completedRun.status())) {
+ System.out.printf("Result counts: %s%n", completedRun.resultCounts());
+ System.out.printf("Eval run report URL: %s%n", completedRun.reportUrl());
+ printOutputItems(openAIClient, eval.id(), evalRun.id());
+ } else {
+ System.out.println("Evaluation run did not complete successfully.");
+ }
+ } finally {
+ if (eval != null) {
+ System.out.printf("Delete evaluation `%s`.%n", eval.id());
+ openAIClient.evals().delete(EvalDeleteParams.builder().evalId(eval.id()).build());
+ }
+ if (job != null) {
+ System.out.printf("Delete the data generation job `%s`.%n", job.getId());
+ dataGenerationJobsClient.deleteGenerationJob(job.getId(), DATA_GENERATION_PREVIEW);
+ }
+ }
+ }
+
+ static DataGenerationJob createDataGenerationJob(String modelName, String datasetName) {
+ PromptDataGenerationJobSource source = new PromptDataGenerationJobSource(
+ "Contoso offers a full refund within 30 days of purchase for any product returned in its original "
+ + "condition. After 30 days, store credit may be issued at the discretion of customer support. "
+ + "Digital goods are non-refundable once downloaded.")
+ .setDescription("Contoso refund policy");
+
+ SimpleQnADataGenerationJobOptions options = new SimpleQnADataGenerationJobOptions(15)
+ .setModelOptions(new DataGenerationModelOptions(modelName));
+
+ DataGenerationJobOutputOptions outputOptions = new DataGenerationJobOutputOptions()
+ .setName(datasetName)
+ .setDescription("QnA pairs generated from the Contoso refund policy prompt.")
+ .setTags(Collections.singletonMap("sample", "dataset-generation-with-evaluation"));
+
+ DataGenerationJobInputs inputs = new DataGenerationJobInputs("qna-from-policy-prompt",
+ Collections.singletonList(source), options, DataGenerationJobScenario.EVALUATION)
+ .setOutputOptions(outputOptions);
+
+ return new DataGenerationJob().setInputs(inputs);
+ }
+
+ private static DataGenerationJob waitForDataGenerationJob(DataGenerationJobsClient dataGenerationJobsClient,
+ String jobId, int pollIntervalSeconds) throws InterruptedException {
+ System.out.printf("Poll job `%s` until it reaches a terminal state.", jobId);
+ DataGenerationJob job;
+ do {
+ Thread.sleep(pollIntervalSeconds * 1000L);
+ System.out.print(".");
+ job = dataGenerationJobsClient.getGenerationJob(jobId, DATA_GENERATION_PREVIEW);
+ } while (!isTerminalStatus(job.getStatus()));
+ System.out.println();
+ return job;
+ }
+
+ static boolean isTerminalStatus(JobStatus status) {
+ return JobStatus.SUCCEEDED.equals(status)
+ || JobStatus.FAILED.equals(status)
+ || JobStatus.CANCELLED.equals(status);
+ }
+
+ static DatasetDataGenerationJobOutput findDatasetOutput(DataGenerationJob job) {
+ if (job.getResult() != null && job.getResult().getOutputs() != null) {
+ for (DataGenerationJobOutput output : job.getResult().getOutputs()) {
+ if (output instanceof DatasetDataGenerationJobOutput) {
+ DatasetDataGenerationJobOutput datasetOutput = (DatasetDataGenerationJobOutput) output;
+ if (datasetOutput.getName() != null && datasetOutput.getVersion() != null) {
+ return datasetOutput;
+ }
+ }
+ }
+ }
+
+ throw new IllegalStateException(String.format("Job `%s` did not produce a dataset output.", job.getId()));
+ }
+
+ static EvalCreateParams createEvaluationParams(String modelName) {
+ Map queryProperty = new LinkedHashMap<>();
+ queryProperty.put("type", "string");
+
+ Map groundTruthProperty = new LinkedHashMap<>();
+ groundTruthProperty.put("type", "string");
+
+ Map properties = new LinkedHashMap<>();
+ properties.put("query", queryProperty);
+ properties.put("ground_truth", groundTruthProperty);
+
+ ItemSchema itemSchema = ItemSchema.builder()
+ .putAdditionalProperty("type", JsonValue.from("object"))
+ .putAdditionalProperty("properties", JsonValue.from(properties))
+ .putAdditionalProperty("required", JsonValue.from(Collections.singletonList("query")))
+ .build();
+
+ Custom dataSourceConfig = Custom.builder()
+ .itemSchema(itemSchema)
+ .includeSampleSchema(true)
+ .build();
+
+ return EvalCreateParams.builder()
+ .name("generated-qna-evaluation")
+ .dataSourceConfig(dataSourceConfig)
+ .testingCriteria(createAzureAIEvaluatorCriteria(modelName))
+ .build();
+ }
+
+ @SuppressWarnings({ "unchecked", "rawtypes" })
+ private static JsonField extends List> createAzureAIEvaluatorCriteria(String modelName) {
+ // openai-java 4.14.0 does not have a typed Azure AI evaluator testing criterion model.
+ return (JsonField) JsonValue.from(Arrays.asList(
+ createAzureAIEvaluator("coherence", "builtin.coherence", modelName,
+ mapOf("query", "{{item.query}}", "response", "{{sample.output_text}}")),
+ createAzureAIEvaluator("fluency", "builtin.fluency", modelName,
+ Collections.singletonMap("response", "{{sample.output_text}}"))));
+ }
+
+ private static Map createAzureAIEvaluator(String name, String evaluatorName, String modelName,
+ Map dataMapping) {
+ Map initializationParameters = new LinkedHashMap<>();
+ initializationParameters.put("deployment_name", modelName);
+
+ Map evaluator = new LinkedHashMap<>();
+ evaluator.put("type", "azure_ai_evaluator");
+ evaluator.put("name", name);
+ evaluator.put("evaluator_name", evaluatorName);
+ evaluator.put("initialization_parameters", initializationParameters);
+ evaluator.put("data_mapping", dataMapping);
+ return evaluator;
+ }
+
+ static RunCreateParams createEvaluationRunParams(String evalId, String datasetId, String modelName) {
+ CreateEvalCompletionsRunDataSource.InputMessages.Template inputMessages
+ = CreateEvalCompletionsRunDataSource.InputMessages.Template.builder()
+ .addTemplate(EasyInputMessage.builder()
+ .role(EasyInputMessage.Role.DEVELOPER)
+ .content("You are a Contoso customer-support assistant. Answer the user's question about the "
+ + "Contoso refund policy clearly and concisely.")
+ .build())
+ .addTemplate(EasyInputMessage.builder()
+ .role(EasyInputMessage.Role.USER)
+ .content("{{item.query}}")
+ .build())
+ .build();
+
+ CreateEvalCompletionsRunDataSource dataSource = CreateEvalCompletionsRunDataSource.builder()
+ .fileIdSource(datasetId)
+ .type(CreateEvalCompletionsRunDataSource.Type.COMPLETIONS)
+ .inputMessages(inputMessages)
+ .model(modelName)
+ .build();
+
+ return RunCreateParams.builder()
+ .evalId(evalId)
+ .name("generated-qna-evaluation-run")
+ .dataSource(dataSource)
+ .build();
+ }
+
+ private static RunRetrieveResponse waitForEvaluationRun(OpenAIClient openAIClient, String evalId, String runId,
+ int pollIntervalSeconds) throws InterruptedException {
+ RunRetrieveResponse evalRun = openAIClient.evals().runs().retrieve(RunRetrieveParams.builder()
+ .evalId(evalId)
+ .runId(runId)
+ .build());
+ while (!"completed".equals(evalRun.status()) && !"failed".equals(evalRun.status())) {
+ Thread.sleep(pollIntervalSeconds * 1000L);
+ evalRun = openAIClient.evals().runs().retrieve(RunRetrieveParams.builder()
+ .evalId(evalId)
+ .runId(runId)
+ .build());
+ }
+ return evalRun;
+ }
+
+ private static void printOutputItems(OpenAIClient openAIClient, String evalId, String runId) {
+ int count = 0;
+ for (OutputItemListResponse item : openAIClient.evals().runs().outputItems().list(OutputItemListParams.builder()
+ .evalId(evalId)
+ .runId(runId)
+ .build()).autoPager()) {
+ count++;
+ System.out.printf(" item %d: status=%s | %s%n", count, item.status(), item.results());
+ }
+ System.out.printf("Output items (total: %d).%n", count);
+ }
+
+ private static Map mapOf(String firstKey, String firstValue, String secondKey, String secondValue) {
+ Map map = new LinkedHashMap<>();
+ map.put(firstKey, firstValue);
+ map.put(secondKey, secondValue);
+ return map;
+ }
+}
diff --git a/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/ClientTestBase.java b/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/ClientTestBase.java
index e8f526ddc5a6..551b40ea0028 100644
--- a/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/ClientTestBase.java
+++ b/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/ClientTestBase.java
@@ -55,10 +55,10 @@ protected AIProjectClientBuilder getClientBuilder(HttpClient httpClient,
builder.endpoint("https://localhost:8080").credential(new MockTokenCredential());
} else if (testMode == TestMode.RECORD) {
builder.addPolicy(interceptorManager.getRecordPolicy())
- .endpoint(Configuration.getGlobalConfiguration().get("AI_PROJECTS_ENDPOINT"))
+ .endpoint(Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT"))
.credential(new DefaultAzureCredentialBuilder().build());
} else {
- builder.endpoint(Configuration.getGlobalConfiguration().get("AI_PROJECTS_ENDPOINT"))
+ builder.endpoint(Configuration.getGlobalConfiguration().get("FOUNDRY_PROJECT_ENDPOINT"))
.credential(new DefaultAzureCredentialBuilder().build());
}
diff --git a/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/SamplesTests.java b/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/SamplesTests.java
index 25da419c09db..0a6c8f2ed104 100644
--- a/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/SamplesTests.java
+++ b/sdk/ai/azure-ai-projects/src/test/java/com/azure/ai/projects/SamplesTests.java
@@ -3,16 +3,31 @@
package com.azure.ai.projects;
import com.azure.ai.agents.models.PageOrder;
+import com.azure.ai.projects.models.ApiError;
import com.azure.ai.projects.models.DataGenerationJob;
+import com.azure.ai.projects.models.DatasetDataGenerationJobOutput;
+import com.azure.ai.projects.models.DatasetVersion;
import com.azure.ai.projects.models.FoundryFeaturesOptInKeys;
+import com.azure.ai.projects.models.JobStatus;
import com.azure.ai.projects.models.ModelVersion;
import com.azure.ai.projects.models.SkillDetails;
import com.azure.core.exception.ResourceNotFoundException;
import com.azure.core.http.HttpClient;
+import com.azure.core.test.TestMode;
import com.azure.core.test.annotation.LiveOnly;
import com.azure.core.util.BinaryData;
+import com.azure.core.util.Configuration;
+import com.openai.client.OpenAIClient;
+import com.openai.models.evals.EvalCreateResponse;
+import com.openai.models.evals.EvalDeleteParams;
+import com.openai.models.evals.runs.RunCreateResponse;
+import com.openai.models.evals.runs.RunRetrieveParams;
+import com.openai.models.evals.runs.RunRetrieveResponse;
+import com.openai.models.evals.runs.outputitems.OutputItemListParams;
+import com.openai.models.evals.runs.outputitems.OutputItemListResponse;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import reactor.test.StepVerifier;
@@ -22,6 +37,7 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
@@ -184,6 +200,72 @@ public void dataGenerationCreateGetCancelDeleteSample(HttpClient httpClient,
"Enable after providing FOUNDRY_MODEL_NAME and deciding whether to record this long-running preview flow.");
}
+ @Timeout(value = 20, unit = TimeUnit.MINUTES)
+ @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
+ @MethodSource("com.azure.ai.projects.TestUtils#getTestParameters")
+ public void dataGenerationJobWithEvaluationSample(HttpClient httpClient, AIProjectsServiceVersion serviceVersion)
+ throws InterruptedException {
+ AIProjectClientBuilder projectClientBuilder = getClientBuilder(httpClient, serviceVersion);
+ DataGenerationJobsClient dataGenerationJobsClient = projectClientBuilder.buildDataGenerationJobsClient();
+ DatasetsClient datasetsClient = projectClientBuilder.buildDatasetsClient();
+ OpenAIClient openAIClient = projectClientBuilder.buildOpenAIClient();
+
+ String modelName = getRecordedConfig("FOUNDRY_MODEL_NAME");
+ String datasetName = testResourceNamer.randomName("dataset-generation-eval-", 64);
+
+ DataGenerationJob job = null;
+ EvalCreateResponse eval = null;
+
+ try {
+ job = dataGenerationJobsClient.createGenerationJob(
+ DataGenerationJobWithEvaluationSample.createDataGenerationJob(modelName, datasetName),
+ DATA_GENERATION_PREVIEW, testResourceNamer.randomUuid());
+
+ job = waitForDataGenerationJob(dataGenerationJobsClient, job.getId(), 5, 180);
+ if (!JobStatus.SUCCEEDED.equals(job.getStatus())) {
+ ApiError error = job.getError();
+ String message = error == null ? "" : error.getMessage();
+ Assertions
+ .fail(String.format("Job `%s` ended with status `%s`: %s", job.getId(), job.getStatus(), message));
+ }
+
+ DatasetDataGenerationJobOutput output = DataGenerationJobWithEvaluationSample.findDatasetOutput(job);
+ DatasetVersion dataset = datasetsClient.getDatasetVersion(output.getName(), output.getVersion());
+ Assertions.assertNotNull(dataset);
+ Assertions.assertNotNull(dataset.getId());
+
+ eval = openAIClient.evals().create(DataGenerationJobWithEvaluationSample.createEvaluationParams(modelName));
+ Assertions.assertNotNull(eval);
+
+ RunCreateResponse evalRun = openAIClient.evals()
+ .runs()
+ .create(DataGenerationJobWithEvaluationSample.createEvaluationRunParams(eval.id(), dataset.getId(),
+ modelName));
+ Assertions.assertNotNull(evalRun);
+
+ RunRetrieveResponse completedRun = waitForEvaluationRun(openAIClient, eval.id(), evalRun.id(), 5, 180);
+ Assertions.assertEquals("completed", completedRun.status());
+ Assertions.assertNotNull(completedRun.resultCounts());
+
+ int outputItemCount = 0;
+ for (OutputItemListResponse ignored : openAIClient.evals()
+ .runs()
+ .outputItems()
+ .list(OutputItemListParams.builder().evalId(eval.id()).runId(evalRun.id()).build())
+ .autoPager()) {
+ outputItemCount++;
+ }
+ Assertions.assertTrue(outputItemCount > 0);
+ } finally {
+ if (eval != null) {
+ openAIClient.evals().delete(EvalDeleteParams.builder().evalId(eval.id()).build());
+ }
+ if (job != null) {
+ dataGenerationJobsClient.deleteGenerationJob(job.getId(), DATA_GENERATION_PREVIEW);
+ }
+ }
+ }
+
@Disabled("Requires FOUNDRY_MODEL_ASSET_NAME, FOUNDRY_MODEL_ASSET_VERSION, and FOUNDRY_MODEL_BLOB_URI.")
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.projects.TestUtils#getTestParameters")
@@ -206,4 +288,42 @@ private static BinaryData createSkillPackage() throws IOException {
return BinaryData.fromBytes(outputStream.toByteArray());
}
+
+ private DataGenerationJob waitForDataGenerationJob(DataGenerationJobsClient dataGenerationJobsClient, String jobId,
+ int pollIntervalSeconds, int maxAttempts) throws InterruptedException {
+ DataGenerationJob job;
+ int attempts = 0;
+ do {
+ sleepIfRunningAgainstService(pollIntervalSeconds * 1000L);
+ job = dataGenerationJobsClient.getGenerationJob(jobId, DATA_GENERATION_PREVIEW);
+ attempts++;
+ } while (!DataGenerationJobWithEvaluationSample.isTerminalStatus(job.getStatus()) && attempts < maxAttempts);
+ return job;
+ }
+
+ private RunRetrieveResponse waitForEvaluationRun(OpenAIClient openAIClient, String evalId, String runId,
+ int pollIntervalSeconds, int maxAttempts) throws InterruptedException {
+ RunRetrieveResponse evalRun
+ = openAIClient.evals().runs().retrieve(RunRetrieveParams.builder().evalId(evalId).runId(runId).build());
+ int attempts = 0;
+ while (!"completed".equals(evalRun.status()) && !"failed".equals(evalRun.status()) && attempts < maxAttempts) {
+ sleepIfRunningAgainstService(pollIntervalSeconds * 1000L);
+ evalRun
+ = openAIClient.evals().runs().retrieve(RunRetrieveParams.builder().evalId(evalId).runId(runId).build());
+ attempts++;
+ }
+ return evalRun;
+ }
+
+ private String getRecordedConfig(String name) {
+ if (getTestMode() == TestMode.PLAYBACK) {
+ return testResourceNamer.recordValueFromConfig(name);
+ }
+
+ String value = Configuration.getGlobalConfiguration().get(name);
+ if (getTestMode() == TestMode.RECORD) {
+ testResourceNamer.recordValueFromConfig(name);
+ }
+ return value;
+ }
}