Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/adapter/cloudflare-workers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
*/

export { serveStatic } from './serve-static-module'
export { upgradeWebSocket } from './websocket'
export { upgradeWebSocket, createWSContext, upgradeWebSocketForDO } from './websocket'
export type { UpgradeWebSocketForDOOptions } from './websocket'
export { getConnInfo } from './conninfo'
142 changes: 141 additions & 1 deletion src/adapter/cloudflare-workers/websocket.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Hono } from '../..'
import { Context } from '../../context'
import { upgradeWebSocket } from '.'
import { upgradeWebSocket, createWSContext, upgradeWebSocketForDO } from '.'

describe('upgradeWebSocket middleware', () => {
const server = new EventTarget()
Expand Down Expand Up @@ -57,3 +57,143 @@ describe('upgradeWebSocket middleware', () => {
expect(next).toBeCalled()
})
})

describe('createWSContext for Hibernation API', () => {
it('Should wrap WebSocket into WSContext', () => {
const mockWs = {
send: vi.fn(),
close: vi.fn(),
protocol: 'graphql-ws',
readyState: 1,
url: 'wss://example.com/ws',
} as unknown as WebSocket

const wsCtx = createWSContext(mockWs)

expect(wsCtx.protocol).toBe('graphql-ws')
expect(wsCtx.readyState).toBe(1)
expect(wsCtx.url?.href).toBe('wss://example.com/ws')
expect(wsCtx.raw).toBe(mockWs)
})

it('Should forward send() to underlying WebSocket', () => {
const mockWs = {
send: vi.fn(),
close: vi.fn(),
protocol: null,
readyState: 1,
url: null,
} as unknown as WebSocket

const wsCtx = createWSContext(mockWs)
wsCtx.send('hello')

expect(mockWs.send).toHaveBeenCalledWith('hello')
})

it('Should forward close() with code and reason', () => {
const mockWs = {
send: vi.fn(),
close: vi.fn(),
protocol: null,
readyState: 1,
url: null,
} as unknown as WebSocket

const wsCtx = createWSContext(mockWs)
wsCtx.close(1000, 'Normal closure')

expect(mockWs.close).toHaveBeenCalledWith(1000, 'Normal closure')
})

it('Should handle null url', () => {
const mockWs = {
send: vi.fn(),
close: vi.fn(),
protocol: null,
readyState: 1,
url: null,
} as unknown as WebSocket

const wsCtx = createWSContext(mockWs)

expect(wsCtx.url).toBe(null)
})
})

describe('upgradeWebSocketForDO', () => {
// Store original Response
const OriginalResponse = globalThis.Response

beforeAll(() => {
// @ts-expect-error Cloudflare API mock
globalThis.WebSocketPair = class {
0: WebSocket
1: WebSocket
constructor() {
this[0] = { client: true } as unknown as WebSocket
this[1] = { server: true } as unknown as WebSocket
}
}

// Mock Response to support status 101 (Cloudflare-specific)
globalThis.Response = class MockResponse {
status: number
webSocket: WebSocket | undefined
constructor(_body: BodyInit | null, init?: ResponseInit & { webSocket?: WebSocket }) {
this.status = init?.status ?? 200
this.webSocket = init?.webSocket
}
} as unknown as typeof Response
})

afterAll(() => {
globalThis.Response = OriginalResponse
})

it('Should return 101 response', () => {
const mockCtx = { acceptWebSocket: vi.fn() }

const response = upgradeWebSocketForDO(mockCtx)

expect(response.status).toBe(101)
})

it('Should call acceptWebSocket with server socket', () => {
const mockCtx = { acceptWebSocket: vi.fn() }

upgradeWebSocketForDO(mockCtx)

expect(mockCtx.acceptWebSocket).toHaveBeenCalled()
const calledWith = mockCtx.acceptWebSocket.mock.calls[0][0]
expect(calledWith).toHaveProperty('server', true)
})

it('Should pass tags to acceptWebSocket', () => {
const mockCtx = { acceptWebSocket: vi.fn() }

upgradeWebSocketForDO(mockCtx, { tags: ['room:123', 'user:456'] })

expect(mockCtx.acceptWebSocket).toHaveBeenCalledWith(expect.anything(), [
'room:123',
'user:456',
])
})

it('Should pass undefined tags when not provided', () => {
const mockCtx = { acceptWebSocket: vi.fn() }

upgradeWebSocketForDO(mockCtx)

expect(mockCtx.acceptWebSocket).toHaveBeenCalledWith(expect.anything(), undefined)
})

it('Should attach client WebSocket to response', () => {
const mockCtx = { acceptWebSocket: vi.fn() }

const response = upgradeWebSocketForDO(mockCtx)

// @ts-expect-error Cloudflare-specific property
expect(response.webSocket).toHaveProperty('client', true)
})
})
83 changes: 83 additions & 0 deletions src/adapter/cloudflare-workers/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,86 @@ export const upgradeWebSocket: UpgradeWebSocket<
webSocket: client,
})
})

/**
* Create a WSContext from a raw Cloudflare WebSocket.
* Use in Durable Object Hibernation API handlers.
*
* @example
* ```ts
* import { createWSContext } from 'hono/cloudflare-workers'
*
* export class ChatRoom extends DurableObject {
* webSocketMessage(ws: WebSocket, message: string | ArrayBuffer) {
* const wsCtx = createWSContext(ws)
* wsCtx.send(`Echo: ${message}`)
* }
* }
* ```
*/
export const createWSContext = (ws: WebSocket): WSContext<WebSocket> => {
return new WSContext<WebSocket>({
close: (code, reason) => ws.close(code, reason),
get protocol() {
return ws.protocol
},
raw: ws,
get readyState() {
return ws.readyState as WSReadyState
},
url: ws.url ? new URL(ws.url) : null,
send: (source) => ws.send(source),
})
}

/**
* Options for upgradeWebSocketForDO
*/
export interface UpgradeWebSocketForDOOptions {
/** Optional tags for the WebSocket (used with getWebSockets(tag)) */
tags?: string[]
}

/**
* Upgrade WebSocket in a Durable Object using Hibernation API.
* Handles WebSocketPair creation and acceptWebSocket.
*
* @param ctx - The Durable Object's state context (this.ctx)
* @param options - Optional configuration (tags)
* @returns Response with status 101 and the client WebSocket attached
*
* @example
* ```ts
* import { upgradeWebSocketForDO } from 'hono/cloudflare-workers'
*
* export class ChatRoom extends DurableObject {
* app = new Hono()
*
* constructor(ctx: DurableObjectState, env: Env) {
* super(ctx, env)
* this.app.get('/ws', (c) => upgradeWebSocketForDO(this.ctx))
* }
*
* fetch(request: Request) {
* return this.app.fetch(request)
* }
* }
* ```
*/
export const upgradeWebSocketForDO = (
ctx: { acceptWebSocket(ws: WebSocket, tags?: string[]): void },
options?: UpgradeWebSocketForDOOptions
): Response => {
// @ts-expect-error WebSocketPair is not typed
const webSocketPair = new WebSocketPair()
const client: WebSocket = webSocketPair[0]
const server: WebSocket = webSocketPair[1]

ctx.acceptWebSocket(server, options?.tags)

return new Response(null, {
status: 101,
// @ts-expect-error webSocket is not typed
webSocket: client,
})
}