Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pairwise): allow passing evaluator directly to wait for evaluation to complete #681

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions js/src/evaluation/evaluate_comparative.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,26 @@ import {
} from "../schemas.js";
import { shuffle } from "../utils/shuffle.js";
import { AsyncCaller } from "../utils/async_caller.js";
import { type evaluate } from "./index.js";
import pRetry from "p-retry";

type ExperimentResults = Awaited<ReturnType<typeof evaluate>>;

/**
 * Type guard that tells a list of experiment results apart from a list of
 * experiment names/IDs.
 *
 * @param value Either experiment result objects or plain string identifiers.
 * @returns `true` as soon as any entry is not a string; `false` when every
 *   entry is a string (which includes the empty list).
 */
function isExperimentResultsList(
  value: ExperimentResults[] | string[]
): value is ExperimentResults[] {
  for (const entry of value) {
    // A single non-string entry is enough to classify the list as results.
    if (typeof entry !== "string") {
      return true;
    }
  }
  return false;
}

async function loadExperiment(
client: Client,
experiment: string | ExperimentResults
) {
const value =
typeof experiment === "string" ? experiment : experiment.experimentName;

function loadExperiment(client: Client, experiment: string) {
return client.readProject(
validate(experiment)
? { projectId: experiment }
: { projectName: experiment }
validate(value) ? { projectId: value } : { projectName: value }
);
}

Expand Down Expand Up @@ -107,7 +121,9 @@ export interface ComparisonEvaluationResults {
}

export async function evaluateComparative(
experiments: Array<string>,
experiments:
| Array<string>
| Array<Promise<ExperimentResults> | ExperimentResults>,
options: EvaluateComparativeOptions
): Promise<ComparisonEvaluationResults> {
if (experiments.length < 2) {
Expand All @@ -125,10 +141,34 @@ export async function evaluateComparative(
}

const client = options.client ?? new Client();
const resolvedExperiments = await Promise.all(experiments);

const projects = await Promise.all(
experiments.map((experiment) => loadExperiment(client, experiment))
);
const projects = await (() => {
if (!isExperimentResultsList(resolvedExperiments)) {
return Promise.all(
resolvedExperiments.map((experiment) =>
loadExperiment(client, experiment)
)
);
}

// if we know the number of runs beforehand, check if the
// number of runs in the project matches the expected number of runs
return Promise.all(
resolvedExperiments.map((experiment) =>
pRetry(
async () => {
const project = await loadExperiment(client, experiment);
if (project.run_count !== experiment?.results.length) {
dqbd marked this conversation as resolved.
Show resolved Hide resolved
throw new Error("Experiment is missing runs. Retrying.");
}
return project;
},
{ factor: 2, minTimeout: 1000, retries: 10 }
)
)
);
})();

if (new Set(projects.map((p) => p.reference_dataset_id)).size > 1) {
throw new Error("All experiments must have the same reference dataset.");
Expand Down Expand Up @@ -211,8 +251,6 @@ export async function evaluateComparative(
)
);

console.dir(experimentRuns, { depth: null });

let exampleIdsIntersect: Set<string> | undefined;
for (const runs of experimentRuns) {
const exampleIdsSet = new Set(
Expand Down
31 changes: 26 additions & 5 deletions js/src/tests/evaluate_comparative.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ beforeAll(async () => {
});

afterAll(async () => {
console.log("Deleting dataset");
// const client = new Client();
// await client.deleteDataset({ datasetName: TESTING_DATASET_NAME });
const client = new Client();
await client.deleteDataset({ datasetName: TESTING_DATASET_NAME });
});

describe("evaluate comparative", () => {
Expand Down Expand Up @@ -59,7 +58,29 @@ describe("evaluate comparative", () => {
}
);

// TODO: we should a) wait for runs to be persisted, b) allow passing runnables / traceables directly
expect(pairwise.results.length).toBeGreaterThanOrEqual(1);
expect(pairwise.results.length).toEqual(2);
});

// Integration test: hands the results of two `evaluate(...)` calls straight to
// `evaluateComparative` instead of passing experiment names/IDs.
test("pass directly", async () => {
const pairwise = await evaluateComparative(
[
// The `evaluate(...)` calls are intentionally not awaited here; their return
// values are passed through as-is for `evaluateComparative` to resolve.
evaluate((input) => ({ foo: `first:${input.input}` }), {
data: TESTING_DATASET_NAME,
}),
evaluate((input) => ({ foo: `second:${input.input}` }), {
data: TESTING_DATASET_NAME,
}),
],
{
evaluators: [
// Pairwise evaluator: scores each run by its position in `runs`
// (index 0 -> 0, index 1 -> 1), keyed by run id.
(runs) => ({
key: "latter_precedence",
scores: Object.fromEntries(runs.map((run, i) => [run.id, i % 2])),
}),
],
}
);

// NOTE(review): presumably the testing dataset holds two examples, giving one
// comparison result per example — confirm against the `beforeAll` setup.
expect(pairwise.results.length).toEqual(2);
});
});
Loading