diff --git a/src/server/model/user.ts b/src/server/model/user.ts index 50c13b8..587fea4 100644 --- a/src/server/model/user.ts +++ b/src/server/model/user.ts @@ -5,7 +5,7 @@ import { jwtVerify } from '../middleware/auth.js'; import { TRPCError } from '@trpc/server'; import { Prisma } from '@prisma/client'; import { AdapterUser } from '@auth/core/adapters'; -import { md5 } from '../utils/common.js'; +import { md5, sha256 } from '../utils/common.js'; import { logger } from '../utils/logger.js'; import { promUserCounter } from '../utils/prometheus/client.js'; @@ -341,3 +341,45 @@ export async function leaveWorkspace(userId: string, workspaceId: string) { throw new Error('Leave Workspace Failed.'); } } + +/** + * Generate User Api Key, for user to call api + */ +export async function generateUserApiKey(userId: string, expiredAt?: Date) { + const apiKey = `sk_${sha256(`${userId}.${Date.now()}`)}`; + + const result = await prisma.userApiKey.create({ + data: { + apiKey, + userId, + expiredAt, + }, + }); + + return result.apiKey; +} + +/** + * Verify User Api Key + */ +export async function verifyUserApiKey(apiKey: string) { + const result = await prisma.userApiKey.findUnique({ + where: { + apiKey, + }, + select: { + user: true, + expiredAt: true, + }, + }); + + if (result?.expiredAt && result.expiredAt.valueOf() < Date.now()) { + throw new Error('Api Key has been expired.'); + } + + if (!result) { + throw new Error('Api Key not found'); + } + + return result.user; +} diff --git a/src/server/prisma/migrations/20241103091612_add_user_api_key/migration.sql b/src/server/prisma/migrations/20241103091612_add_user_api_key/migration.sql new file mode 100644 index 0000000..129095e --- /dev/null +++ b/src/server/prisma/migrations/20241103091612_add_user_api_key/migration.sql @@ -0,0 +1,16 @@ +-- CreateTable +CREATE TABLE "UserApiKey" ( + "apiKey" VARCHAR(128) NOT NULL, + "userId" TEXT NOT NULL, + "createdAt" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMPTZ(6) NOT NULL, + "expiredAt" TIMESTAMP(3), + + CONSTRAINT "UserApiKey_pkey" PRIMARY KEY ("apiKey") +); + +-- CreateIndex +CREATE UNIQUE INDEX "UserApiKey_apiKey_key" ON "UserApiKey"("apiKey"); + +-- AddForeignKey +ALTER TABLE "UserApiKey" ADD CONSTRAINT "UserApiKey_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/src/server/prisma/schema.prisma b/src/server/prisma/schema.prisma index e009cdc..a144742 100644 --- a/src/server/prisma/schema.prisma +++ b/src/server/prisma/schema.prisma @@ -34,6 +34,17 @@ model User { accounts Account[] sessions Session[] workspaces WorkspacesOnUsers[] + apiKeys UserApiKey[] +} + +model UserApiKey { + apiKey String @id @unique @db.VarChar(128) + userId String + createdAt DateTime @default(now()) @db.Timestamptz(6) + updatedAt DateTime @updatedAt @db.Timestamptz(6) + expiredAt DateTime? + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) } model Account { diff --git a/src/server/prisma/zod/index.ts b/src/server/prisma/zod/index.ts index e5c8086..825a053 100644 --- a/src/server/prisma/zod/index.ts +++ b/src/server/prisma/zod/index.ts @@ -1,4 +1,5 @@ export * from "./user.js" +export * from "./userapikey.js" export * from "./account.js" export * from "./session.js" export * from "./verificationtoken.js" diff --git a/src/server/prisma/zod/user.ts b/src/server/prisma/zod/user.ts index 179f5c5..0b20bf1 100644 --- a/src/server/prisma/zod/user.ts +++ b/src/server/prisma/zod/user.ts @@ -1,6 +1,6 @@ import * as z from "zod" import * as imports from "./schemas/index.js" -import { CompleteAccount, RelatedAccountModelSchema, CompleteSession, RelatedSessionModelSchema, CompleteWorkspacesOnUsers, RelatedWorkspacesOnUsersModelSchema } from "./index.js" +import { CompleteAccount, RelatedAccountModelSchema, CompleteSession, RelatedSessionModelSchema, CompleteWorkspacesOnUsers, RelatedWorkspacesOnUsersModelSchema, CompleteUserApiKey, RelatedUserApiKeyModelSchema } from "./index.js" export const UserModelSchema = z.object({ id: z.string(), @@ -21,6 +21,7 @@ export interface CompleteUser extends z.infer { accounts: CompleteAccount[] sessions: CompleteSession[] workspaces: CompleteWorkspacesOnUsers[] + apiKeys: CompleteUserApiKey[] } /** @@ -32,4 +33,5 @@ export const RelatedUserModelSchema: z.ZodSchema = z.lazy(() => Us accounts: RelatedAccountModelSchema.array(), sessions: RelatedSessionModelSchema.array(), workspaces: RelatedWorkspacesOnUsersModelSchema.array(), + apiKeys: RelatedUserApiKeyModelSchema.array(), })) diff --git a/src/server/prisma/zod/userapikey.ts b/src/server/prisma/zod/userapikey.ts new file mode 100644 index 0000000..fa84cc3 --- /dev/null +++ b/src/server/prisma/zod/userapikey.ts @@ -0,0 +1,24 @@ +import * as z from "zod" +import * as imports from "./schemas/index.js" +import { CompleteUser, RelatedUserModelSchema } from "./index.js" + +export const UserApiKeyModelSchema = z.object({ + apiKey: z.string(), + userId: z.string(), + createdAt: z.date(), + updatedAt: z.date(), + expiredAt: z.date().nullish(), +}) + +export interface CompleteUserApiKey extends z.infer { + user: CompleteUser +} + +/** + * RelatedUserApiKeyModelSchema contains all relations on your model in addition to the scalars + * + * NOTE: Lazy required in case of potential circular dependencies within schema + */ +export const RelatedUserApiKeyModelSchema: z.ZodSchema = z.lazy(() => UserApiKeyModelSchema.extend({ + user: RelatedUserModelSchema, +})) diff --git a/src/server/trpc/trpc.ts b/src/server/trpc/trpc.ts index dfbef70..9159b7c 100644 --- a/src/server/trpc/trpc.ts +++ b/src/server/trpc/trpc.ts @@ -9,6 +9,7 @@ import { getSession } from '@auth/express'; import { authConfig } from '../model/auth.js'; import { get } from 'lodash-es'; import { promTrpcRequest } from '../utils/prometheus/client.js'; +import { verifyUserApiKey } from '../model/user.js'; export async function createContext({ req }: { req: Request }) { const authorization = req.headers['authorization'] ?? ''; @@ -57,16 +58,30 @@ const isUser = middleware(async (opts) => { const token = opts.ctx.token; if (token) { - try { - const user = jwtVerify(token); + if (token.startsWith('sk_')) { + // auth with api key + const user = await verifyUserApiKey(token); return opts.next({ ctx: { - user, + id: user.id, + username: user.username, + role: user.role, }, }); - } catch (err) { - throw new TRPCError({ code: 'UNAUTHORIZED', message: 'TokenInvalid' }); + } else { + // auth with jwt + try { + const user = jwtVerify(token); + + return opts.next({ + ctx: { + user, + }, + }); + } catch (err) { + throw new TRPCError({ code: 'UNAUTHORIZED', message: 'TokenInvalid' }); + } } } diff --git a/src/server/utils/__tests__/common.test.ts b/src/server/utils/__tests__/common.test.ts index e5738d9..38aa8da 100644 --- a/src/server/utils/__tests__/common.test.ts +++ b/src/server/utils/__tests__/common.test.ts @@ -1,5 +1,5 @@ import { describe, expect, test } from 'vitest'; -import { md5 } from '../common.js'; +import { md5, sha256 } from '../common.js'; describe('md5', () => { test('should return the correct md5 hash', () => { @@ -21,3 +21,15 @@ describe('md5', () => { expect(result1).not.toEqual(result2); }); }); + +describe('sha256', () => { + test('should return the correct sha256 hash', () => { + const input = 'test'; + const expectedHash = + '9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08'; + + const result = sha256(input); + + expect(result).toEqual(expectedHash); + }); +}); diff --git a/src/server/utils/common.ts b/src/server/utils/common.ts index 7ae76e3..944ccad 100644 --- a/src/server/utils/common.ts +++ b/src/server/utils/common.ts @@ -48,6 +48,13 @@ export function hashUuid(...args: string[]) { return v5(hash(...args), v5.DNS); } +export function sha512(input: string) { + return hash(input); +} + +export function sha256(input: string) { + return crypto.createHash('sha256').update(input).digest('hex'); +} /** * generate hash with md5 * which use in unimportant scene diff --git a/src/server/ws/index.ts b/src/server/ws/index.ts index c2c67c4..6c72df0 100644 --- a/src/server/ws/index.ts +++ b/src/server/ws/index.ts @@ -5,6 +5,7 @@ import { socketEventBus } from './shared.js'; import { isCuid } from '../utils/common.js'; import { logger } from '../utils/logger.js'; import { getAuthSession, UserAuthPayload } from '../model/auth.js'; +import { verifyUserApiKey } from '../model/user.js'; export function initSocketio(httpServer: HTTPServer) { const io = new SocketIOServer(httpServer, { @@ -28,12 +29,23 @@ export function initSocketio(httpServer: HTTPServer) { let user: UserAuthPayload; if (token) { - user = jwtVerify(token); - logger.info( - '[WebSocket] Authenticated via JWT:', - user.id, - user.username - ); + if (token.startsWith('sk_')) { + // auth with api key + const _user = await verifyUserApiKey(token); + + user = { + id: _user.id, + username: _user.username, + role: _user.role, + }; + } else { + user = jwtVerify(token); + logger.info( + '[WebSocket] Authenticated via JWT:', + user.id, + user.username + ); + } } else { const session = await getAuthSession( socket.request,