60 lines
1.7 KiB
Go
60 lines
1.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
)
|
|
|
|
// ToolExec executes a single tool invocation requested by the assistant.
// It is given the tool's name and the argument string exactly as the
// assistant supplied it (presumably JSON, per the parameter name — it is not
// validated before the call). Its return value is recorded by Run as the
// content of the corresponding "tool" message.
type ToolExec func(ctx context.Context, toolName string, argsJSON string) string
|
|
|
|
// maxAgentIterations bounds how many Chat round-trips Run makes while the
// assistant keeps requesting tool calls. When the bound is reached without a
// tool-free reply, Run gives up and returns an error.
const (
	maxAgentIterations = 50
)
|
|
|
|
// Agent is a stateless driver: given a message slice and tools, it advances
|
|
// the conversation until the assistant replies with no tool calls.
|
|
type Agent struct {
|
|
Name string
|
|
Client *Client
|
|
SystemPrompt string
|
|
Tools []Tool
|
|
ToolExec ToolExec
|
|
OnToolCall func(agent string, tc ToolCall, result string)
|
|
}
|
|
|
|
// Run advances the given messages, returning the final assistant text and
|
|
// the updated message slice (including tool calls + tool results).
|
|
func (a *Agent) Run(ctx context.Context, messages []Message) (string, []Message, error) {
|
|
for i := 0; i < maxAgentIterations; i++ {
|
|
msg, err := a.Client.Chat(ctx, messages, a.Tools)
|
|
if err != nil {
|
|
return "", messages, err
|
|
}
|
|
messages = append(messages, msg)
|
|
if len(msg.ToolCalls) == 0 {
|
|
return msg.Content, messages, nil
|
|
}
|
|
for _, tc := range msg.ToolCalls {
|
|
result := a.ToolExec(ctx, tc.Function.Name, tc.Function.Arguments)
|
|
if a.OnToolCall != nil {
|
|
a.OnToolCall(a.Name, tc, result)
|
|
}
|
|
messages = append(messages, Message{
|
|
Role: "tool",
|
|
ToolCallID: tc.ID,
|
|
Name: tc.Function.Name,
|
|
Content: result,
|
|
})
|
|
}
|
|
}
|
|
return "", messages, fmt.Errorf("%s: exceeded %d iterations without final reply", a.Name, maxAgentIterations)
|
|
}
|
|
|
|
// Do is a one-shot helper: fresh conversation of system+user → final text.
|
|
func (a *Agent) Do(ctx context.Context, userMsg string) (string, error) {
|
|
messages := []Message{
|
|
{Role: "system", Content: a.SystemPrompt},
|
|
{Role: "user", Content: userMsg},
|
|
}
|
|
reply, _, err := a.Run(ctx, messages)
|
|
return reply, err
|
|
}
|