Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ideas for PromptTemplate composition and type safety improvements #5157

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
134 changes: 79 additions & 55 deletions langchain-core/src/prompts/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import {
type ChatPromptValueInterface,
ChatPromptValue,
} from "../prompt_values.js";
import type { InputValues, PartialValues } from "../utils/types/index.js";
import type { InputValues, InputValues_FSTRING, PartialValues } from "../utils/types/index.js";
import { Runnable } from "../runnables/base.js";
import { BaseStringPromptTemplate } from "./string.js";
import {
Expand All @@ -29,6 +29,11 @@ import { PromptTemplate, type ParamsFromFString } from "./prompt.js";
import { ImagePromptTemplate } from "./image.js";
import { parseFString } from "./template.js";


type ObjectTupleIntersection<A, T extends [...unknown[]]> = T['length'] extends 0
? A
: ObjectTupleIntersection<A & T[0], T extends [unknown, ...infer R] ? R : []>;

/**
* Abstract class that serves as a base for creating message prompt
* templates. It defines how to format messages for different roles in a
Expand Down Expand Up @@ -86,12 +91,11 @@ export interface MessagesPlaceholderFields<T extends string> {
* extends the BaseMessagePromptTemplate.
*/
export class MessagesPlaceholder<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput extends InputValues = any
>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput extends InputValues = any
>
extends BaseMessagePromptTemplate<RunInput>
implements MessagesPlaceholderFields<Extract<keyof RunInput, string>>
{
implements MessagesPlaceholderFields<Extract<keyof RunInput, string>> {
static lc_name() {
return "MessagesPlaceholder";
}
Expand Down Expand Up @@ -209,12 +213,12 @@ export abstract class BaseMessageStringPromptTemplate<
constructor(
fields:
| MessageStringPromptTemplateFields<
InputValues<Extract<keyof RunInput, string>>
>
InputValues<Extract<keyof RunInput, string>>
>
| BaseStringPromptTemplate<
InputValues<Extract<keyof RunInput, string>>,
string
>
InputValues<Extract<keyof RunInput, string>>,
string
>
) {
if (!("prompt" in fields)) {
// eslint-disable-next-line no-param-reassign
Expand Down Expand Up @@ -312,8 +316,8 @@ export class ChatMessagePromptTemplate<
constructor(
fields:
| ChatMessagePromptTemplateFields<
InputValues<Extract<keyof RunInput, string>>
>
InputValues<Extract<keyof RunInput, string>>
>
| BaseStringPromptTemplate<InputValues<Extract<keyof RunInput, string>>>,
role?: string
) {
Expand Down Expand Up @@ -351,6 +355,25 @@ type MessageClass =

type ChatMessageClass = typeof ChatMessage;

type _StringImageMessagePromptTemplatePrompt<RunInput extends InputValues> =
| BaseStringPromptTemplate<
RunInput,
string
> | PromptTemplate<RunInput>
| Array<
| BaseStringPromptTemplate<
RunInput,
string
>
| ImagePromptTemplate<
RunInput,
string
>
| MessageStringPromptTemplateFields<
RunInput
>
>;

class _StringImageMessagePromptTemplate<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput extends InputValues = any,
Expand All @@ -360,28 +383,11 @@ class _StringImageMessagePromptTemplate<

lc_serializable = true;

inputVariables: Array<Extract<keyof RunInput, string>> = [];

additionalOptions: Record<string, unknown> = {};

prompt:
| BaseStringPromptTemplate<
InputValues<Extract<keyof RunInput, string>>,
string
>
| Array<
| BaseStringPromptTemplate<
InputValues<Extract<keyof RunInput, string>>,
string
>
| ImagePromptTemplate<
InputValues<Extract<keyof RunInput, string>>,
string
>
| MessageStringPromptTemplateFields<
InputValues<Extract<keyof RunInput, string>>
>
>;
prompt: _StringImageMessagePromptTemplatePrompt<RunInput>;

inputVariables: Array<Extract<keyof RunInput, string>> = [];

protected messageClass?: MessageClass;

Expand All @@ -396,9 +402,12 @@ class _StringImageMessagePromptTemplate<
protected chatMessageClass?: ChatMessageClass;

constructor(
/** @TODO When we come up with a better way to type prompt templates, fix this */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
fields: any,
fields:
| {
prompt: _StringImageMessagePromptTemplatePrompt<RunInput>;
additionalOptions?: Record<string, unknown>;
}
| _StringImageMessagePromptTemplatePrompt<RunInput>,
additionalOptions?: Record<string, unknown>
) {
if (!("prompt" in fields)) {
Expand Down Expand Up @@ -454,15 +463,16 @@ class _StringImageMessagePromptTemplate<
}
}

static fromTemplate(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
static fromTemplate<RunInput extends InputValues = any>(
template: string | Array<string | _TextTemplateParam | _ImageTemplateParam>,
additionalOptions?: Record<string, unknown>
) {
if (typeof template === "string") {
return new this(PromptTemplate.fromTemplate(template));
return new this(PromptTemplate.fromTemplate<RunInput>(template));
}
const prompt: Array<
PromptTemplate<InputValues> | ImagePromptTemplate<InputValues>
PromptTemplate<RunInput> | ImagePromptTemplate<RunInput>
> = [];
for (const item of template) {
if (
Expand All @@ -475,10 +485,10 @@ class _StringImageMessagePromptTemplate<
} else if (typeof item.text === "string") {
text = item.text ?? "";
}
prompt.push(PromptTemplate.fromTemplate(text));
prompt.push(PromptTemplate.fromTemplate(text) as typeof prompt[number]);
} else if (typeof item === "object" && "image_url" in item) {
let imgTemplate = item.image_url ?? "";
let imgTemplateObject: ImagePromptTemplate<InputValues>;
let imgTemplateObject: ImagePromptTemplate;
let inputVariables: string[] = [];
if (typeof imgTemplate === "string") {
const parsedTemplate = parseFString(imgTemplate);
Expand All @@ -498,7 +508,7 @@ class _StringImageMessagePromptTemplate<
}

imgTemplate = { url: imgTemplate };
imgTemplateObject = new ImagePromptTemplate<InputValues>({
imgTemplateObject = new ImagePromptTemplate({
template: imgTemplate,
inputVariables,
});
Expand All @@ -511,7 +521,7 @@ class _StringImageMessagePromptTemplate<
} else {
inputVariables = [];
}
imgTemplateObject = new ImagePromptTemplate<InputValues>({
imgTemplateObject = new ImagePromptTemplate({
template: imgTemplate,
inputVariables,
});
Expand Down Expand Up @@ -664,8 +674,13 @@ export interface ChatPromptTemplateInput<
validateTemplate?: boolean;
}

export type BaseMessagePromptTemplateLike =
| BaseMessagePromptTemplate
type ChatPromptTemplateLike<RunInput extends InputValues> =
| ChatPromptTemplate<RunInput, string>
| BaseMessagePromptTemplateLike<RunInput>;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type BaseMessagePromptTemplateLike<RunInput extends InputValues = any> =
| BaseMessagePromptTemplate<RunInput>
| BaseMessageLike;

function _isBaseMessagePromptTemplate(
Expand Down Expand Up @@ -765,14 +780,13 @@ function isMessagesPlaceholder(
* ```
*/
export class ChatPromptTemplate<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput extends InputValues = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
PartialVariableName extends string = any
>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput extends InputValues = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
PartialVariableName extends string = any
>
extends BaseChatPromptTemplate<RunInput, PartialVariableName>
implements ChatPromptTemplateInput<RunInput, PartialVariableName>
{
implements ChatPromptTemplateInput<RunInput, PartialVariableName> {
static lc_name() {
return "ChatPromptTemplate";
}
Expand Down Expand Up @@ -944,14 +958,14 @@ export class ChatPromptTemplate<
*/
static fromTemplate<
// eslint-disable-next-line @typescript-eslint/ban-types
RunInput extends InputValues = Symbol,
RunInput extends InputValues = InputValues_FSTRING,
T extends string = string
>(template: T) {
const prompt = PromptTemplate.fromTemplate(template);
const prompt = PromptTemplate.fromTemplate<RunInput>(template);
const humanTemplate = new HumanMessagePromptTemplate({ prompt });
return this.fromMessages<
// eslint-disable-next-line @typescript-eslint/ban-types
RunInput extends Symbol ? ParamsFromFString<T> : RunInput
RunInput extends InputValues_FSTRING ? ParamsFromFString<T> : RunInput
>([humanTemplate]);
}

Expand Down Expand Up @@ -1012,6 +1026,16 @@ export class ChatPromptTemplate<
});
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
static fromTypedMessages<const Prompts extends ChatPromptTemplateLike<any>[]>(messages: Prompts): ChatPromptTemplate<
// eslint-disable-next-line @typescript-eslint/ban-types
ObjectTupleIntersection<{}, {
[P in keyof Prompts]: Prompts[P] extends ChatPromptTemplateLike<infer I> ? I : never
}>
> {
return this.fromMessages(messages);
}

/** @deprecated Renamed to .fromMessages */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
static fromPromptMessages<RunInput extends InputValues = any>(
Expand Down
10 changes: 7 additions & 3 deletions langchain-core/src/prompts/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ import {
type TemplateFormat,
} from "./template.js";
import type { SerializedPromptTemplate } from "./serde.js";
import type { InputValues, PartialValues } from "../utils/types/index.js";
import type {
InputValues,
InputValues_FSTRING,
PartialValues,
} from "../utils/types/index.js";
import { MessageContent } from "../messages/index.js";

/**
Expand Down Expand Up @@ -183,7 +187,7 @@ export class PromptTemplate<
*/
static fromTemplate<
// eslint-disable-next-line @typescript-eslint/ban-types
RunInput extends InputValues = Symbol,
RunInput extends InputValues = InputValues_FSTRING,
T extends string = string
>(
template: T,
Expand All @@ -203,7 +207,7 @@ export class PromptTemplate<
});
return new PromptTemplate<
// eslint-disable-next-line @typescript-eslint/ban-types
RunInput extends Symbol ? ParamsFromFString<T> : RunInput
RunInput extends InputValues_FSTRING ? ParamsFromFString<T> : RunInput
>({
// Rely on extracted types
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down
26 changes: 26 additions & 0 deletions langchain-core/src/prompts/tests/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,29 @@ test("Multi-modal, multi part chat prompt works with instances of BaseMessage",
});
expect(messages).toMatchSnapshot();
});

test("fromTypedMessages combines message input types", () => {
const x = new SystemMessagePromptTemplate<{ x: string }>({
prompt: PromptTemplate.fromTemplate(""),
});
const y = new SystemMessagePromptTemplate<{ y: string }>({
prompt: PromptTemplate.fromTemplate(""),
});

const xy = ChatPromptTemplate.fromTypedMessages([x, y]);
((test: ChatPromptTemplate<{ x: string; y: string }>) => test)(xy);
});

test("_StringImageMessagePromptTemplate infers RunInput from child prompt", () => {
const x = new SystemMessagePromptTemplate({
prompt: PromptTemplate.fromTemplate<{ x: number; y: boolean }>(""),
});

((test: SystemMessagePromptTemplate<{ x: number; y: boolean }>) => test)(x);

// @ts-expect-error - nope
((test: SystemMessagePromptTemplate<{ x: number; z: number }>) => test)(x);

// @ts-expect-error - nope
((test: SystemMessagePromptTemplate<{ x: string; y: string }>) => test)(x);
});
4 changes: 3 additions & 1 deletion langchain-core/src/utils/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ export * from "./is_zod_schema.js";
export type StringWithAutocomplete<T> = T | (string & Record<never, never>);

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type InputValues<K extends string = string> = Record<K, any>;
export type InputValues<K extends string = string, V = any> = Record<K, V>;

export type InputValues_FSTRING = InputValues & { __FSTRING: true };

export type PartialValues<K extends string = string> = Record<
K,
Expand Down
4 changes: 3 additions & 1 deletion langchain/src/experimental/prompts/custom_format.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import {
TypedPromptInputValues,
} from "@langchain/core/prompts";

export type InputValues_FSTRING = InputValues & { __FSTRING: true };

export type CustomFormatPromptTemplateInput<RunInput extends InputValues> =
Omit<PromptTemplateInput<RunInput, string>, "templateFormat"> & {
customParser: (template: string) => ParsedFStringNode[];
Expand Down Expand Up @@ -74,7 +76,7 @@ export class CustomFormatPromptTemplate<
}
}
// eslint-disable-next-line @typescript-eslint/ban-types
return new this<RunInput extends Symbol ? never : RunInput>({
return new this<RunInput extends InputValues_FSTRING ? never : RunInput>({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
inputVariables: [...names] as any[],
template,
Expand Down