Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion neon_data_models/models/api/http/brainforge.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def to_llm_inference_http_request(self) -> LLMGetInferenceHttpRequest:
query=query,
history=history,
persona=self.persona,
**self.extra_body
extra_body=self.extra_body
)


Expand Down
56 changes: 37 additions & 19 deletions neon_data_models/models/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import List, Tuple, Optional, Literal
from typing import Any, Dict, List, Tuple, Optional, Literal
from pydantic import Field, model_validator, computed_field

from neon_data_models.models.base import BaseModel
Expand Down Expand Up @@ -80,7 +80,6 @@ def validate_request(self):

class LLMRequest(BaseModel):
query: str = Field(description="Incoming user prompt")
# TODO: History may support more options in the future
history: List[Tuple[LlmMessageRole, str]] = Field(
description="Formatted chat history (excluding system prompt). Note "
"that the roles used here will differ from those used in "
Expand All @@ -96,23 +95,31 @@ class LLMRequest(BaseModel):
description="Temperature of response. 0 guarantees reproducibility, "
"higher values increase variability. Must be `0.0` if "
"`beam_search` is True")
repetition_penalty: float = Field(
default=1.0, ge=1.0, le=2.0,
description="Repetition penalty. Higher values limit repeated "
"information in responses")
stream: bool = Field(
default=None, description="Enable streaming responses. "
"Mutually exclusive with `beam_search`.")
best_of: int = Field(
default=1, ge=1,
description="Number of beams to use if `beam_search` is enabled.")
beam_search: bool = Field(
default=None, description="Enable beam search. "
"Mutually exclusive with `stream`.")
max_history: int = Field(
default=2, description="Maximum number of user/assistant "
"message pairs to include in history context. "
"Excludes system prompt and incoming query.")
extra_body: Dict[str, Any] = Field(
description="Optional dict of additional request body parameters")

@property
def repetition_penalty(self) -> float:
    """
    Repetition penalty forwarded to the backend via `extra_body`.

    The value lives under the `repetition_penalty` key of `extra_body`
    rather than in a dedicated model field.

    :returns: the stored penalty, or `1.0` when the key is absent (the
        same default `validate_inputs` writes into `extra_body`).
    """
    # `extra_body` is a plain mutable dict, so the key can disappear after
    # validation; fall back to the default instead of raising `KeyError`.
    return self.extra_body.get('repetition_penalty', 1.0)

@property
def beam_search(self) -> bool:
    """
    Beam-search flag forwarded to the backend via `extra_body`.

    The value lives under the `use_beam_search` key of `extra_body`.

    :returns: the stored flag, or `None` when the key is absent (the same
        "unset" default `validate_inputs` writes into `extra_body`).
    """
    # `extra_body` is a plain mutable dict, so the key can disappear after
    # validation; report "unset" instead of raising `KeyError`.
    return self.extra_body.get('use_beam_search')

@beam_search.setter
def beam_search(self, value: bool):
    """
    Store the beam-search flag under `extra_body["use_beam_search"]`.

    :param value: new flag value to write into `extra_body`
    """
    self.extra_body["use_beam_search"] = value

@property
def best_of(self) -> int:
    """
    `best_of` request parameter forwarded to the backend via `extra_body`.

    The value lives under the `best_of` key of `extra_body`.

    :returns: the stored value, or `1` when the key is absent (the same
        default `validate_inputs` writes into `extra_body`).
    """
    # `extra_body` is a plain mutable dict, so the key can disappear after
    # validation; fall back to the default instead of raising `KeyError`.
    return self.extra_body.get('best_of', 1)

@model_validator(mode='before')
@classmethod
Expand All @@ -125,6 +132,18 @@ def validate_inputs(cls, values):
# OpenAI `extra_body` may be included in input; parse those inputs
if values.get('use_beam_search') is not None:
values['beam_search'] = values['use_beam_search']

values.setdefault("extra_body", {})
values['extra_body'].setdefault("add_special_tokens", True)
if values.get('repetition_penalty') is not None:
values['extra_body']['repetition_penalty'] = values['repetition_penalty']
values['extra_body'].setdefault('repetition_penalty', 1.0)
if values.get('beam_search') is not None:
values['extra_body']['use_beam_search'] = values['beam_search']
values['extra_body'].setdefault('use_beam_search', None)
if values.get('best_of') is not None:
values['extra_body']['best_of'] = values['best_of']
values['extra_body'].setdefault('best_of', 1)
return values

@model_validator(mode='after')
Expand Down Expand Up @@ -156,12 +175,14 @@ def validate_request(self):
raise ValueError("Cannot enable both `stream` and "
"`beam_search`")
self.stream = False
if self.stream is None and self.beam_search is None:
if self.stream is None and self.beam_search in (None, False):
self.stream = True
self.beam_search = False
elif self.stream is None:
self.stream = False

assert isinstance(self.stream, bool)
assert isinstance(self.beam_search, bool)
assert isinstance(self.stream, bool), f"Expected `stream` to be a bool, got {type(self.stream)}"
assert isinstance(self.beam_search, bool), f"Expected `beam_search` to be a bool, got {type(self.beam_search)}"

# If beam search is enabled, temperature must be set to 0.0
if self.beam_search:
Expand Down Expand Up @@ -194,10 +215,7 @@ def to_completion_kwargs(self, mq2role: dict = None) -> dict:
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"stream": self.stream,
"extra_body": {"add_special_tokens": True,
"repetition_penalty": self.repetition_penalty,
"use_beam_search": self.beam_search,
"best_of": self.best_of}}
"extra_body": self.extra_body}


class LLMResponse(BaseModel):
Expand Down
17 changes: 16 additions & 1 deletion neon_data_models/models/api/mq/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Optional, Dict, List
from pydantic import Field
from pydantic import Field, model_validator

from neon_data_models.models.api.llm import LLMRequest, LLMPersona
from neon_data_models.models.base.contexts import MQContext
Expand All @@ -42,6 +42,11 @@ class LLMProposeRequest(MQContext, LLMRequest):
"parameter, with default behavior hard-coded into each "
"LLM module.")

@model_validator(mode="before")
@classmethod
def validate_inputs(cls, values):
    """
    Pre-validation hook that defers to `LLMRequest.validate_inputs`.

    :param values: raw input passed to the model before field validation
    :returns: whatever `LLMRequest.validate_inputs` returns for `values`
    """
    # NOTE(review): presumably re-declared so the hook is registered on this
    # subclass as well — confirm against pydantic's inheritance behavior.
    normalized = LLMRequest.validate_inputs(values)
    return normalized


class LLMProposeResponse(MQContext):
response: str = Field(description="LLM response to the prompt")
Expand All @@ -51,6 +56,11 @@ class LLMDiscussRequest(LLMProposeRequest):
options: Dict[str, str] = Field(
description="Mapping of participant name to response to be discussed.")

@model_validator(mode="before")
@classmethod
def validate_inputs(cls, values):
    """
    Pre-validation hook that defers to `LLMRequest.validate_inputs`.

    :param values: raw input passed to the model before field validation
    :returns: whatever `LLMRequest.validate_inputs` returns for `values`
    """
    # NOTE(review): presumably re-declared so the hook is registered on this
    # subclass as well — confirm against pydantic's inheritance behavior.
    normalized = LLMRequest.validate_inputs(values)
    return normalized


class LLMDiscussResponse(MQContext):
opinion: str = Field(description="LLM response to the available options.")
Expand All @@ -60,6 +70,11 @@ class LLMVoteRequest(LLMProposeRequest):
responses: List[str] = Field(
description="List of responses to choose from.")

@model_validator(mode="before")
@classmethod
def validate_inputs(cls, values):
    """
    Pre-validation hook that defers to `LLMRequest.validate_inputs`.

    :param values: raw input passed to the model before field validation
    :returns: whatever `LLMRequest.validate_inputs` returns for `values`
    """
    # NOTE(review): presumably re-declared so the hook is registered on this
    # subclass as well — confirm against pydantic's inheritance behavior.
    normalized = LLMRequest.validate_inputs(values)
    return normalized


class LLMVoteResponse(MQContext):
sorted_answer_indexes: List[int] = Field(
Expand Down