Skip to content

Commit

Permalink
Merge pull request #151 from mendableai/feat/rate-limits
Browse files Browse the repository at this point in the history
[Feat] Added rate limits
  • Loading branch information
nickscamara committed May 19, 2024
2 parents 713f16f + 98a39b3 commit 842e197
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 61 deletions.
4 changes: 4 additions & 0 deletions apps/api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ SUPABASE_SERVICE_TOKEN=

# Other Optionals
TEST_API_KEY= # use if you've set up authentication and want to test with a real API key
RATE_LIMIT_TEST_API_KEY_SCRAPE= # set if you'd like to test the scraping rate limit
RATE_LIMIT_TEST_API_KEY_CRAWL= # set if you'd like to test the crawling rate limit
SCRAPING_BEE_API_KEY= #Set if you'd like to use scraping Be to handle JS blocking
OPENAI_API_KEY= # add for LLM dependednt features (image alt generation, etc.)
BULL_AUTH_KEY= #
Expand All @@ -27,3 +29,5 @@ SLACK_WEBHOOK_URL= # set if you'd like to send slack server health status messag
POSTHOG_API_KEY= # set if you'd like to send posthog events like job logs
POSTHOG_HOST= # set if you'd like to send posthog events like job logs

STRIPE_PRICE_ID_STANDARD=
STRIPE_PRICE_ID_SCALE=
61 changes: 61 additions & 0 deletions apps/api/src/__tests__/e2e_withAuth/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -955,4 +955,65 @@ describe("E2E Tests for API Routes", () => {
expect(response.body).toHaveProperty("isProduction");
});
});

describe("Rate Limiter", () => {
it("should return 429 when rate limit is exceeded for preview token", async () => {
for (let i = 0; i < 5; i++) {
const response = await request(TEST_URL)
.post("/v0/scrape")
.set("Authorization", `Bearer this_is_just_a_preview_token`)
.set("Content-Type", "application/json")
.send({ url: "https://www.scrapethissite.com" });

expect(response.statusCode).toBe(200);
}
const response = await request(TEST_URL)
.post("/v0/scrape")
.set("Authorization", `Bearer this_is_just_a_preview_token`)
.set("Content-Type", "application/json")
.send({ url: "https://www.scrapethissite.com" });

expect(response.statusCode).toBe(429);
}, 60000);
});

// it("should return 429 when rate limit is exceeded for API key", async () => {
// for (let i = 0; i < parseInt(process.env.RATE_LIMIT_TEST_API_KEY_SCRAPE); i++) {
// const response = await request(TEST_URL)
// .post("/v0/scrape")
// .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
// .set("Content-Type", "application/json")
// .send({ url: "https://www.scrapethissite.com" });

// expect(response.statusCode).toBe(200);
// }

// const response = await request(TEST_URL)
// .post("/v0/scrape")
// .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
// .set("Content-Type", "application/json")
// .send({ url: "https://www.scrapethissite.com" });

// expect(response.statusCode).toBe(429);
// }, 60000);

// it("should return 429 when rate limit is exceeded for API key", async () => {
// for (let i = 0; i < parseInt(process.env.RATE_LIMIT_TEST_API_KEY_CRAWL); i++) {
// const response = await request(TEST_URL)
// .post("/v0/crawl")
// .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
// .set("Content-Type", "application/json")
// .send({ url: "https://www.scrapethissite.com" });

// expect(response.statusCode).toBe(200);
// }

// const response = await request(TEST_URL)
// .post("/v0/crawl")
// .set("Authorization", `Bearer ${process.env.TEST_API_KEY}`)
// .set("Content-Type", "application/json")
// .send({ url: "https://www.scrapethissite.com" });

// expect(response.statusCode).toBe(429);
// }, 60000);
});
124 changes: 106 additions & 18 deletions apps/api/src/controllers/auth.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { parseApi } from "../../src/lib/parseApi";
import { getRateLimiter } from "../../src/services/rate-limiter";
import { getRateLimiter, } from "../../src/services/rate-limiter";
import { AuthResponse, RateLimiterMode } from "../../src/types";
import { supabase_service } from "../../src/services/supabase";
import { withAuth } from "../../src/lib/withAuth";

import { RateLimiterRedis } from "rate-limiter-flexible";

export async function authenticateUser(req, res, mode?: RateLimiterMode) : Promise<AuthResponse> {
return withAuth(supaAuthenticateUser)(req, res, mode);
Expand All @@ -19,7 +19,6 @@ export async function supaAuthenticateUser(
error?: string;
status?: number;
}> {

const authHeader = req.headers.authorization;
if (!authHeader) {
return { success: false, error: "Unauthorized", status: 401 };
Expand All @@ -33,13 +32,85 @@ export async function supaAuthenticateUser(
};
}

const incomingIP = (req.headers["x-forwarded-for"] ||
req.socket.remoteAddress) as string;
const iptoken = incomingIP + token;

let rateLimiter: RateLimiterRedis;
let subscriptionData: { team_id: string, plan: string } | null = null;
let normalizedApi: string;

if (token == "this_is_just_a_preview_token") {
rateLimiter = getRateLimiter(RateLimiterMode.Preview, token);
} else {
normalizedApi = parseApi(token);

const { data, error } = await supabase_service.rpc(
'get_key_and_price_id_2', { api_key: normalizedApi }
);
// get_key_and_price_id_2 rpc definition:
// create or replace function get_key_and_price_id_2(api_key uuid)
// returns table(key uuid, team_id uuid, price_id text) as $$
// begin
// if api_key is null then
// return query
// select null::uuid as key, null::uuid as team_id, null::text as price_id;
// end if;

// return query
// select ak.key, ak.team_id, s.price_id
// from api_keys ak
// left join subscriptions s on ak.team_id = s.team_id
// where ak.key = api_key;
// end;
// $$ language plpgsql;

if (error) {
console.error('Error fetching key and price_id:', error);
} else {
// console.log('Key and Price ID:', data);
}

if (error || !data || data.length === 0) {
return {
success: false,
error: "Unauthorized: Invalid token",
status: 401,
};
}


subscriptionData = {
team_id: data[0].team_id,
plan: getPlanByPriceId(data[0].price_id)
}
switch (mode) {
case RateLimiterMode.Crawl:
rateLimiter = getRateLimiter(RateLimiterMode.Crawl, token, subscriptionData.plan);
break;
case RateLimiterMode.Scrape:
rateLimiter = getRateLimiter(RateLimiterMode.Scrape, token, subscriptionData.plan);
break;
case RateLimiterMode.CrawlStatus:
rateLimiter = getRateLimiter(RateLimiterMode.CrawlStatus, token);
break;
case RateLimiterMode.Search:
rateLimiter = getRateLimiter(RateLimiterMode.Search, token);
break;
case RateLimiterMode.Preview:
rateLimiter = getRateLimiter(RateLimiterMode.Preview, token);
break;
default:
rateLimiter = getRateLimiter(RateLimiterMode.Crawl, token);
break;
// case RateLimiterMode.Search:
// rateLimiter = await searchRateLimiter(RateLimiterMode.Search, token);
// break;
}
}

try {
const incomingIP = (req.headers["x-forwarded-for"] ||
req.socket.remoteAddress) as string;
const iptoken = incomingIP + token;
await getRateLimiter(
token === "this_is_just_a_preview_token" ? RateLimiterMode.Preview : mode, token
).consume(iptoken);
await rateLimiter.consume(iptoken);
} catch (rateLimiterRes) {
console.error(rateLimiterRes);
return {
Expand All @@ -66,19 +137,36 @@ export async function supaAuthenticateUser(
// return { success: false, error: "Unauthorized: Invalid token", status: 401 };
}

const normalizedApi = parseApi(token);
// make sure api key is valid, based on the api_keys table in supabase
const { data, error } = await supabase_service
if (!subscriptionData) {
normalizedApi = parseApi(token);

const { data, error } = await supabase_service
.from("api_keys")
.select("*")
.eq("key", normalizedApi);
if (error || !data || data.length === 0) {
return {
success: false,
error: "Unauthorized: Invalid token",
status: 401,
};

if (error || !data || data.length === 0) {
return {
success: false,
error: "Unauthorized: Invalid token",
status: 401,
};
}

subscriptionData = data[0];
}

return { success: true, team_id: data[0].team_id };
return { success: true, team_id: subscriptionData.team_id };
}

function getPlanByPriceId(price_id: string) {
switch (price_id) {
case process.env.STRIPE_PRICE_ID_STANDARD:
return 'standard';
case process.env.STRIPE_PRICE_ID_SCALE:
return 'scale';
default:
return 'starter';
}
}

0 comments on commit 842e197

Please sign in to comment.