Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class OpenAITextEmbedding implements EmbeddingModel {

private final String apiKey;
private final String modelName;
private final int dimensions;
private final Integer dimensions;
private final ExecutionConfig defaultExecutionConfig;

private final String baseUrl;
Expand All @@ -66,7 +66,7 @@ public class OpenAITextEmbedding implements EmbeddingModel {
public OpenAITextEmbedding(
String apiKey,
String modelName,
int dimensions,
Integer dimensions,
ExecutionConfig defaultExecutionConfig,
String baseUrl) {
this.apiKey = apiKey;
Expand Down Expand Up @@ -132,15 +132,19 @@ public Mono<double[]> embed(ContentBlock block) {

OpenAIClient client = clientBuilder.build();

EmbeddingCreateParams createParams =
EmbeddingCreateParams.Builder paramsBuilder =
EmbeddingCreateParams.builder()
.model(modelName)
.dimensions(dimensions)
.encodingFormat(
EmbeddingCreateParams.EncodingFormat
.FLOAT)
.inputOfArrayOfStrings(List.of(text))
.build();
.inputOfArrayOfStrings(List.of(text));

if (dimensions != null && dimensions > 0) {
paramsBuilder.dimensions(dimensions);
}

EmbeddingCreateParams createParams = paramsBuilder.build();

log.debug(
"OpenAI embedding call: model={},"
Expand Down Expand Up @@ -182,7 +186,8 @@ public Mono<double[]> embed(ContentBlock block) {
embeddingValues);

// Validate dimension
if (embeddingArray.length != dimensions) {
if (dimensions != null
&& embeddingArray.length != dimensions) {
log.warn(
"Embedding dimension mismatch: expected={},"
+ " actual={}",
Expand Down Expand Up @@ -225,7 +230,7 @@ public String getModelName() {

@Override
public int getDimensions() {
return dimensions;
return dimensions != null ? dimensions : 0;
}

/**
Expand All @@ -234,7 +239,7 @@ public int getDimensions() {
public static class Builder {
private String apiKey;
private String modelName;
private int dimensions = 1536;
private Integer dimensions;
private ExecutionConfig defaultExecutionConfig;
private String baseUrl;

Expand Down Expand Up @@ -266,7 +271,7 @@ public Builder modelName(String modelName) {
* @param dimensions the dimension
* @return this builder instance
*/
public Builder dimensions(int dimensions) {
public Builder dimensions(Integer dimensions) {
this.dimensions = dimensions;
return this;
}
Expand Down Expand Up @@ -311,8 +316,9 @@ public OpenAITextEmbedding build() {
throw new IllegalStateException(
"modelName is required and cannot be null or empty");
}
if (dimensions <= 0) {
throw new IllegalStateException("dimensions must be positive, got: " + dimensions);
if (dimensions != null && dimensions <= 0) {
throw new IllegalStateException(
"dimensions must be positive if provided, got: " + dimensions);
}

ExecutionConfig effectiveConfig =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,19 @@ void testDimensionsValidation() {
.dimensions(-1)
.build());
}

@Test
@DisplayName("Should support optional dimensions in builder")
void testOptionalDimensions() {
OpenAITextEmbedding model =
OpenAITextEmbedding.builder()
.apiKey(TEST_API_KEY)
.modelName(TEST_MODEL_NAME)
.build();

assertNotNull(model);
assertEquals(TEST_MODEL_NAME, model.getModelName());

assertEquals(0, model.getDimensions());
}
}
Loading