Add POST _unified for the inference API #3313


Merged: 6 commits, merged on Jan 13, 2025
Changes from 1 commit
Addressing feedback and removing response
jonathan-buttner committed Jan 10, 2025
commit 90f9fd26b7b5dc9499db0133f54fb24e0245fa83
3 changes: 3 additions & 0 deletions specification/_types/Binary.ts
@@ -22,3 +22,6 @@ export type MapboxVectorTiles = ArrayBuffer

// ES|QL columns
export type EsqlColumns = ArrayBuffer

// Streaming endpoints response
export type StreamResult = ArrayBuffer
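
The new StreamResult alias deliberately stays opaque: streaming endpoints respond with raw bytes. A minimal decoding sketch, assuming the bytes carry UTF-8 text as chat-completion streams typically do; the helper is illustrative and not part of this PR:

// Decode one raw StreamResult chunk into text. UTF-8 is an assumption
// about the stream contents; the spec only promises raw bytes.
const utf8 = new TextDecoder('utf-8')

function decodeChunk(chunk: ArrayBuffer): string {
  // `stream: true` keeps partial multi-byte sequences buffered across calls.
  return utf8.decode(chunk, { stream: true })
}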
114 changes: 0 additions & 114 deletions specification/inference/_types/Results.ts
@@ -88,120 +88,6 @@ export class InferenceResult
rerank?: Array<RankedDocument>
}

/**
* The function the model wants to call.
*/
export class ResultFunctionCall {
/**
* The arguments to call the function with, which the model generated in JSON format.
*/
arguments?: string
/**
* The name of the function to call.
*/
name?: string
}

/**
* The tool call made by the model.
*/
export class ResultToolCall {
index: number
/**
* The identifier of the tool call.
*/
id?: string
/**
* The function the model wants to call.
*/
function?: ResultFunctionCall
/**
* The type of the tool.
*/
type?: string
}

export class CompletionDelta {
/**
* The contents of the chunked message.
*/
content?: string
/**
* The refusal message.
*/
refusal?: string
/**
* The role of the author of the message.
*/
role?: string
/**
* The tool calls made by the model.
*/
tool_calls?: Array<ResultToolCall>
}

/**
* Represent a completion choice returned from a model.
*/
export class CompletionChoice {
/**
* The delta generated by the model.
*/
delta: CompletionDelta
/**
* The reason the model stopped generating tokens.
*/
finish_reason?: string
/**
* The index of the choice in the array of choices field.
*/
index: number
}

/**
* The token usage statistics for the entire request.
*/
export class Usage {
/**
* The number of tokens in the generated completion.
*/
completion_tokens: number
/**
* The number of tokens in the prompt.
*/
prompt_tokens: number
/**
* The sum of completion_tokens and prompt_tokens.
*/
total_tokens: number
}

/**
* Represents the result format for a completion request using the Unified Inference API.
*/
export class UnifiedInferenceResult {
/**
* A unique identifier for the chat completion.
*/
id: string
/**
* A list of completion choices.
*/
choices: Array<CompletionChoice>
/**
* The model that generated the completion.
*/
model: string
/**
* The object type.
*/
object: string
/**
* The token usage statistics for the entire request.
*/
usage?: Usage
}

/**
* Acknowledged response. For dry_run, contains the list of pipelines which reference the inference endpoint
*/
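
With these typed result classes removed, the unified endpoint no longer declares a structured response body; the response becomes an opaque byte stream (see UnifiedResponse.ts below). A client that still wants typed access can reconstruct the deleted shapes on its own side. The sketch below is hypothetical: the interfaces mirror the removed CompletionDelta, CompletionChoice, and Usage fields, and the parsing assumes an OpenAI-style server-sent-event payload, which this PR does not itself specify.

// Hypothetical client-side types mirroring the spec classes deleted above.
interface ChunkDelta {
  content?: string
  refusal?: string
  role?: string
}

interface ChunkChoice {
  delta: ChunkDelta
  finish_reason?: string
  index: number
}

interface CompletionChunk {
  id: string
  choices: ChunkChoice[]
  model: string
  object: string
  usage?: { completion_tokens: number; prompt_tokens: number; total_tokens: number }
}

// Parse one SSE line of the form `data: {...}` into a CompletionChunk.
// The `data:` framing and the `[DONE]` sentinel are assumptions borrowed
// from OpenAI-style streams, not guarantees made by this specification.
function parseEvent(line: string): CompletionChunk | null {
  if (!line.startsWith('data:')) return null
  const payload = line.slice('data:'.length).trim()
  if (payload.length === 0 || payload === '[DONE]') return null
  return JSON.parse(payload) as CompletionChunk
}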
137 changes: 74 additions & 63 deletions specification/inference/unified_inference/UnifiedRequest.ts
@@ -21,8 +21,74 @@ import { TaskType } from '@inference/_types/TaskType'
import { UserDefinedValue } from '@spec_utils/UserDefinedValue'
import { RequestBase } from '@_types/Base'
import { Id } from '@_types/common'
import { float, long } from '@_types/Numeric'
import { Duration } from '@_types/Time'

/**
* Perform inference on the service using the Unified Schema
* @rest_spec_name inference.unified_inference
* @availability stack since=8.18.0 stability=stable visibility=public
* @availability serverless stability=stable visibility=public
*/
export interface Request extends RequestBase {
path_parts: {
/**
* The task type
*/
task_type?: TaskType
/**
* The inference Id
*/
inference_id: Id
}
query_parameters: {
/**
* Specifies the amount of time to wait for the inference request to complete.
* @server_default 30s
*/
timeout?: Duration
}
body: {
/**
* A list of objects representing the conversation.
*/
messages: Array<Message>
/**
* The ID of the model to use.
*/
model?: string
/**
* The upper bound limit for the number of tokens that can be generated for a completion request.
*/
max_completion_tokens?: long
/**
* A sequence of strings to control when the model should stop generating additional tokens.
*/
stop?: Array<string>
/**
* The sampling temperature to use.
*/
temperature?: float
/**
* Controls which tool is called by the model.
*/
tool_choice?: CompletionToolType
/**
* A list of tools that the model can call.
*/
tools?: Array<CompletionTool>
/**
* Nucleus sampling, an alternative to sampling with temperature.
*/
top_p?: float
}
}

/**
* @codegen_names string, object
*/
export type CompletionToolType = string | CompletionToolChoice

/**
* An object-style representation of a single portion of a conversation.
*/
@@ -58,7 +124,7 @@ export interface ToolCall {
/**
* The identifier of the tool call.
*/
-  id: string
+  id: Id
/**
* The function that the model called.
*/
@@ -69,22 +135,27 @@
type: string
}

/**
* @codegen_names string, object
*/
export type MessageContent = string | Array<ContentObject>

/**
* An object representing part of the conversation.
*/
export interface Message {
/**
* The content of the message.
*/
-  content: string | Array<ContentObject>
+  content?: MessageContent
/**
* The role of the message author.
*/
role: string
/**
* The tool call that this message is responding to.
*/
-  tool_call_id?: string
+  tool_call_id?: Id
/**
* The tool calls generated by the model.
*/
@@ -152,63 +223,3 @@ export interface CompletionTool {
*/
function: CompletionToolFunction
}

/**
* Perform inference on the service using the Unified Schema
* @rest_spec_name inference.unified_inference
* @availability stack since=8.18.0 stability=stable visibility=public
* @availability serverless stability=stable visibility=public
*/
export interface Request extends RequestBase {
path_parts: {
/**
* The task type
*/
task_type?: TaskType
/**
* The inference Id
*/
inference_id: Id
}
query_parameters: {
/**
* Specifies the amount of time to wait for the inference request to complete.
* @server_default 30s
*/
timeout?: Duration
}
body: {
/**
* A list of objects representing the conversation.
*/
messages: Array<Message>
/**
* The ID of the model to use.
*/
model?: string
/**
* The upper bound limit for the number of tokens that can be generated for a completion request.
*/
max_completion_tokens?: number
/**
* A sequence of strings to control when the model should stop generating additional tokens.
*/
stop?: Array<string>
/**
* The sampling temperature to use.
*/
temperature?: number
/**
* Controls which tool is called by the model.
*/
tool_choice?: string | CompletionToolChoice
/**
* A list of tools that the model can call.
*/
tools?: Array<CompletionTool>
/**
* Nucleus sampling, an alternative to sampling with temperature.
*/
top_p?: number
}
}
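
For reference, a request body sketch matching the body shape declared above. All values are placeholders; the CompletionToolChoice and CompletionTool definitions are collapsed in this diff, so the object forms below assume an OpenAI-style { type, function } layout. Note that tool_choice accepts either the string or the object branch of CompletionToolType.

// Placeholder values throughout; only the field names and their types
// come from the Request interface above.
const body = {
  model: 'gpt-4o', // assumed model name, not part of the spec
  messages: [
    { role: 'user', content: 'What is the weather like in Paris?' }
  ],
  max_completion_tokens: 256,
  temperature: 0.2,
  top_p: 0.9,
  stop: ['\n\n'],
  // The string form of CompletionToolType would be e.g. tool_choice: 'auto';
  // the object form (CompletionToolChoice) is shown here.
  tool_choice: {
    type: 'function',
    function: { name: 'get_current_weather' }
  },
  tools: [
    {
      type: 'function',
      function: {
        name: 'get_current_weather',
        description: 'Gets the current weather for a location',
        parameters: {
          type: 'object',
          properties: { location: { type: 'string' } }
        }
      }
    }
  ]
}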
4 changes: 2 additions & 2 deletions specification/inference/unified_inference/UnifiedResponse.ts
@@ -17,8 +17,8 @@
* under the License.
*/

- import { UnifiedInferenceResult } from '@inference/_types/Results'
+ import { StreamResult } from '@_types/Binary'

export class Response {
-  body: UnifiedInferenceResult
+  body: StreamResult
}
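
Because the body is now a StreamResult rather than a typed object, consuming the endpoint means reading raw bytes off the wire. A minimal consumption sketch, assuming an HTTP transport with Node 18+ fetch and a UTF-8 event stream; the URL, transport, and printing are illustrative, not part of the specification:

// Stream a unified inference response and print it as it arrives.
// The endpoint URL passed in is an assumption for illustration.
async function streamUnifiedCompletion(url: string, body: unknown): Promise<void> {
  const response = await fetch(url, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(body)
  })
  if (!response.ok || response.body === null) {
    throw new Error(`request failed: ${response.status}`)
  }
  const reader = response.body.getReader()
  const decoder = new TextDecoder('utf-8')
  for (;;) {
    const { done, value } = await reader.read()
    if (done) break
    // Each chunk is raw bytes, mirroring the StreamResult ArrayBuffer alias.
    console.log(decoder.decode(value, { stream: true }))
  }
}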