Skip to content

Commit

Permalink
Implement serde, getTuple(), and list() in BaseCheckpointSaver (
Browse files Browse the repository at this point in the history
#119)

* Update BaseCheckpointSaver class  with getTuple() and list() functions.

* Fix bug when selecting max ts (string).

* Remove console.log().

* Fix bug in put(). Add unit tests.

* Fix format errors.

* Fix lint errors.

* Fix prettier format issues.

* Create NoopSerializer for MemorySaver checkpointer.

* Implement SqliteSaver.

* Implement put() and list() for SqliteSaver and add unit tests.

* Change default checkpoint configuration to END_OF_STEP.

* Refactor type casting for Row type and remove isSetup field.

* Add unit test for verifying parentTs in sqlite db.

* Change Sqlite schema to snake_case to match Python implementation.

* Fix race condition bug with SqliteSaver constructor.

* Implement generics for checkpoint saver and serde protocol.

* Fix type error in Graph classes.

* Require input type of serializer for checkpointers to be Checkpoint type.
  • Loading branch information
andrewnguonly committed Apr 23, 2024
1 parent 79345a0 commit c27db27
Show file tree
Hide file tree
Showing 13 changed files with 1,554 additions and 772 deletions.
4 changes: 3 additions & 1 deletion langgraph/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@langchain/core": "^0.1.51"
"@langchain/core": "^0.1.51",
"better-sqlite3": "^9.5.0"
},
"devDependencies": {
"@jest/globals": "^29.5.0",
Expand All @@ -47,6 +48,7 @@
"@swc/core": "^1.3.90",
"@swc/jest": "^0.2.29",
"@tsconfig/recommended": "^1.0.3",
"@types/better-sqlite3": "^7.6.9",
"@typescript-eslint/eslint-plugin": "^6.12.0",
"@typescript-eslint/parser": "^6.12.0",
"dotenv": "^16.3.1",
Expand Down
20 changes: 1 addition & 19 deletions langgraph/src/channels/base.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Checkpoint } from "../checkpoint/index.js";
import { Checkpoint, deepCopy } from "../checkpoint/index.js";

export abstract class BaseChannel<
ValueType = unknown,
Expand Down Expand Up @@ -63,24 +63,6 @@ export class InvalidUpdateError extends Error {
}
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function deepCopy(obj: any): any {
if (typeof obj !== "object" || obj === null) {
return obj;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const newObj: any = Array.isArray(obj) ? [] : {};

for (const key in obj) {
if (Object.prototype.hasOwnProperty.call(obj, key)) {
newObj[key] = deepCopy(obj[key]);
}
}

return newObj;
}

export function emptyChannels(
channels: Record<string, BaseChannel>,
checkpoint: Checkpoint
Expand Down
81 changes: 75 additions & 6 deletions langgraph/src/checkpoint/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ export interface Checkpoint {
versionsSeen: Record<string, Record<string, number>>;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function deepCopy(obj: any): any {
if (typeof obj !== "object" || obj === null) {
return obj;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const newObj: any = Array.isArray(obj) ? [] : {};

for (const key in obj) {
if (Object.prototype.hasOwnProperty.call(obj, key)) {
newObj[key] = deepCopy(obj[key]);
}
}

return newObj;
}

export function emptyCheckpoint(): Checkpoint {
return {
v: 1,
Expand All @@ -56,7 +74,7 @@ export function copyCheckpoint(checkpoint: Checkpoint): Checkpoint {
ts: checkpoint.ts,
channelValues: { ...checkpoint.channelValues },
channelVersions: { ...checkpoint.channelVersions },
versionsSeen: { ...checkpoint.versionsSeen },
versionsSeen: deepCopy(checkpoint.versionsSeen),
};
}

Expand All @@ -65,14 +83,65 @@ export const enum CheckpointAt {
END_OF_RUN = "end_of_run",
}

export abstract class BaseCheckpointSaver {
at: CheckpointAt = CheckpointAt.END_OF_RUN;
export interface CheckpointTuple {
config: RunnableConfig;
checkpoint: Checkpoint;
parentConfig?: RunnableConfig;
}

const CheckpointThreadId: ConfigurableFieldSpec = {
id: "threadId",
annotation: typeof "",
name: "Thread ID",
description: null,
default: "",
isShared: true,
dependencies: null,
};

const CheckpointThreadTs: ConfigurableFieldSpec = {
id: "threadTs",
annotation: typeof "",
name: "Thread Timestamp",
description:
"Pass to fetch a past checkpoint. If None, fetches the latest checkpoint.",
default: null,
isShared: true,
dependencies: null,
};

export interface SerializerProtocol<D, L> {
dumps(obj: D): L;
loads(data: L): D;
}

export abstract class BaseCheckpointSaver<L> {
at: CheckpointAt = CheckpointAt.END_OF_STEP;

serde: SerializerProtocol<Checkpoint, L>;

constructor(serde?: SerializerProtocol<Checkpoint, L>, at?: CheckpointAt) {
this.serde = serde || this.serde;
this.at = at || this.at;
}

get configSpecs(): Array<ConfigurableFieldSpec> {
return [];
return [CheckpointThreadId, CheckpointThreadTs];
}

abstract get(config: RunnableConfig): Checkpoint | undefined;
async get(config: RunnableConfig): Promise<Checkpoint | undefined> {
const value = await this.getTuple(config);
return value ? value.checkpoint : undefined;
}

abstract getTuple(
config: RunnableConfig
): Promise<CheckpointTuple | undefined>;

abstract list(config: RunnableConfig): AsyncGenerator<CheckpointTuple>;

abstract put(config: RunnableConfig, checkpoint: Checkpoint): void;
abstract put(
config: RunnableConfig,
checkpoint: Checkpoint
): Promise<RunnableConfig>;
}
1 change: 1 addition & 0 deletions langgraph/src/checkpoint/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export {
type ConfigurableFieldSpec,
type Checkpoint,
type CheckpointAt,
deepCopy,
emptyCheckpoint,
BaseCheckpointSaver,
} from "./base.js";
114 changes: 92 additions & 22 deletions langgraph/src/checkpoint/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,103 @@ import {
BaseCheckpointSaver,
Checkpoint,
CheckpointAt,
ConfigurableFieldSpec,
CheckpointTuple,
copyCheckpoint,
SerializerProtocol,
} from "./base.js";

export class MemorySaver extends BaseCheckpointSaver {
storage: Record<string, Checkpoint> = {};

get configSpecs(): ConfigurableFieldSpec[] {
return [
{
id: "threadId",
name: "Thread ID",
annotation: null,
description: null,
default: null,
isShared: true,
dependencies: null,
},
];
export class NoopSerializer
implements SerializerProtocol<Checkpoint, Checkpoint>
{
dumps(obj: Checkpoint): Checkpoint {
return obj;
}

loads(data: Checkpoint): Checkpoint {
return data;
}
}

export class MemorySaver extends BaseCheckpointSaver<Checkpoint> {
serde = new NoopSerializer();

storage: Record<string, Record<string, Checkpoint>>;

constructor(
serde?: SerializerProtocol<Checkpoint, Checkpoint>,
at?: CheckpointAt
) {
super(serde, at);
this.storage = {};
}

async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
const threadId = config.configurable?.threadId;
const threadTs = config.configurable?.threadTs;
const checkpoints = this.storage[threadId];

if (threadTs) {
const checkpoint = checkpoints[threadTs];
if (checkpoint) {
return {
config,
checkpoint: this.serde.loads(checkpoint),
};
}
} else {
if (checkpoints) {
const maxThreadTs = Object.keys(checkpoints).sort((a, b) =>
b.localeCompare(a)
)[0];
return {
config: { configurable: { threadId, threadTs: maxThreadTs } },
checkpoint: this.serde.loads(checkpoints[maxThreadTs.toString()]),
};
}
}

return undefined;
}

get(config: RunnableConfig): Checkpoint | undefined {
return this.storage[config.configurable?.threadId];
async *list(config: RunnableConfig): AsyncGenerator<CheckpointTuple> {
const threadId = config.configurable?.threadId;
const checkpoints = this.storage[threadId] ?? {};

// sort in desc order
for (const [threadTs, checkpoint] of Object.entries(checkpoints).sort(
(a, b) => b[0].localeCompare(a[0])
)) {
yield {
config: { configurable: { threadId, threadTs } },
checkpoint: this.serde.loads(checkpoint),
};
}
}

put(config: RunnableConfig, checkpoint: Checkpoint): void {
this.storage[config.configurable?.threadId] = checkpoint;
async put(
config: RunnableConfig,
checkpoint: Checkpoint
): Promise<RunnableConfig> {
const threadId = config.configurable?.threadId;

if (this.storage[threadId]) {
this.storage[threadId][checkpoint.ts] = this.serde.dumps(checkpoint);
} else {
this.storage[threadId] = {
[checkpoint.ts]: this.serde.dumps(checkpoint),
};
}

return {
configurable: {
threadId,
threadTs: checkpoint.ts,
},
};
}
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export class MemorySaverAssertImmutable extends MemorySaver {
storageForCopies: Record<string, Record<string, Checkpoint>> = {};

Expand All @@ -42,13 +109,16 @@ export class MemorySaverAssertImmutable extends MemorySaver {
this.at = CheckpointAt.END_OF_STEP;
}

put(config: RunnableConfig, checkpoint: Checkpoint): void {
async put(
config: RunnableConfig,
checkpoint: Checkpoint
): Promise<RunnableConfig> {
const threadId = config.configurable?.threadId;
if (!this.storageForCopies[threadId]) {
this.storageForCopies[threadId] = {};
}
// assert checkpoint hasn't been modified since last written
const saved = super.get(config);
const saved = await super.get(config);
if (saved) {
const savedTs = saved.ts;
if (this.storageForCopies[threadId][savedTs]) {
Expand Down

0 comments on commit c27db27

Please sign in to comment.