diff --git a/src/index.ts b/src/index.ts index 826f5c5e4e..9eba347ee0 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,7 +12,7 @@ import type { ReconnectContext, ReconnectPolicy } from './room/ReconnectPolicy'; import Room, { ConnectionState, type RoomEventCallbacks } from './room/Room'; import * as attributes from './room/attribute-typings'; // FIXME: remove this import in a follow up data track pull request. -import './room/data-track/depacketizer'; +import './room/data-track/incoming/IncomingDataTrackManager'; // FIXME: remove this import in a follow up data track pull request. import './room/data-track/outgoing/OutgoingDataTrackManager'; import LocalParticipant from './room/participant/LocalParticipant'; diff --git a/src/room/data-track/track.ts b/src/room/data-track/LocalDataTrack.ts similarity index 76% rename from src/room/data-track/track.ts rename to src/room/data-track/LocalDataTrack.ts index 5a2888ab5e..25a9f86c69 100644 --- a/src/room/data-track/track.ts +++ b/src/room/data-track/LocalDataTrack.ts @@ -1,22 +1,25 @@ import type { DataTrackFrame } from './frame'; -import { type DataTrackHandle } from './handle'; import type OutgoingDataTrackManager from './outgoing/OutgoingDataTrackManager'; +import { + DataTrackSymbol, + type IDataTrack, + type ILocalTrack, + TrackSymbol, +} from './track-interfaces'; +import type { DataTrackInfo } from './types'; -export type DataTrackSid = string; +export default class LocalDataTrack implements ILocalTrack, IDataTrack { + readonly trackSymbol = TrackSymbol; -/** Information about a published data track. */ -export type DataTrackInfo = { - sid: DataTrackSid; - pubHandle: DataTrackHandle; - name: String; - usesE2ee: boolean; -}; + readonly isLocal = true; + + readonly typeSymbol = DataTrackSymbol; -export class LocalDataTrack { info: DataTrackInfo; protected manager: OutgoingDataTrackManager; + /** @internal */ constructor(info: DataTrackInfo, manager: OutgoingDataTrackManager) { this.info = info; this.manager = manager; diff --git a/src/room/data-track/RemoteDataTrack.ts b/src/room/data-track/RemoteDataTrack.ts new file mode 100644 index 0000000000..3185568bdb --- /dev/null +++ b/src/room/data-track/RemoteDataTrack.ts @@ -0,0 +1,82 @@ +import type Participant from '../participant/Participant'; +import type { DataTrackFrame } from './frame'; +import type IncomingDataTrackManager from './incoming/IncomingDataTrackManager'; +import { + DataTrackSymbol, + type IDataTrack, + type IRemoteTrack, + TrackSymbol, +} from './track-interfaces'; +import { type DataTrackInfo } from './types'; + +type RemoteDataTrackOptions = { + publisherIdentity: Participant['identity']; +}; + +export type RemoteDataTrackSubscribeOptions = { + signal?: AbortSignal; + + /** The number of {@link DataTrackFrame}s to hold in the ReadableStream before disgarding extra + * frames. Defaults to 4, but this may not be good enough for especially high frequency data. */ + highWaterMark?: number; +}; + +export default class RemoteDataTrack implements IRemoteTrack, IDataTrack { + readonly trackSymbol = TrackSymbol; + + readonly isLocal = false; + + readonly typeSymbol = DataTrackSymbol; + + info: DataTrackInfo; + + publisherIdentity: Participant['identity']; + + protected manager: IncomingDataTrackManager; + + /** @internal */ + constructor( + info: DataTrackInfo, + manager: IncomingDataTrackManager, + options: RemoteDataTrackOptions, + ) { + this.info = info; + this.manager = manager; + this.publisherIdentity = options.publisherIdentity; + } + + /** Subscribes to the data track to receive frames. + * + * # Returns + * + * A stream that yields {@link DataTrackFrame}s as they arrive. + * + * # Multiple Subscriptions + * + * An application may call `subscribe` more than once to process frames in + * multiple places. For example, one async task might plot values on a graph + * while another writes them to a file. + * + * Internally, only the first call to `subscribe` communicates with the SFU and + * allocates the resources required to receive frames. Additional subscriptions + * reuse the same underlying pipeline and do not trigger additional signaling. + * + * Note that newly created subscriptions only receive frames published after + * the initial subscription is established. + */ + async subscribe( + options?: RemoteDataTrackSubscribeOptions, + ): Promise> { + try { + const stream = await this.manager.subscribeRequest( + this.info.sid, + options?.signal, + options?.highWaterMark, + ); + return stream; + } catch (err) { + // NOTE: Rethrow errors to break Throws<...> type boundary + throw err; + } + } +} diff --git a/src/room/data-track/depacketizer.ts b/src/room/data-track/depacketizer.ts index 730193afe0..d5813da9c0 100644 --- a/src/room/data-track/depacketizer.ts +++ b/src/room/data-track/depacketizer.ts @@ -21,7 +21,7 @@ type PartialFrame = { /** An error indicating a frame was dropped. */ export class DataTrackDepacketizerDropError< - Reason extends DataTrackDepacketizerDropReason, + Reason extends DataTrackDepacketizerDropReason = DataTrackDepacketizerDropReason, > extends LivekitReasonedError { readonly name = 'DataTrackDepacketizerDropError'; @@ -99,13 +99,7 @@ export default class DataTrackDepacketizer { push( packet: DataTrackPacket, options?: PushOptions, - ): Throws< - DataTrackFrame | null, - | DataTrackDepacketizerDropError - | DataTrackDepacketizerDropError - | DataTrackDepacketizerDropError - | DataTrackDepacketizerDropError - > { + ): Throws { switch (packet.header.marker) { case FrameMarker.Single: return this.frameFromSingle(packet, options); @@ -191,13 +185,7 @@ export default class DataTrackDepacketizer { /** Push to the existing partial frame. */ private pushToPartial( packet: DataTrackPacket, - ): Throws< - DataTrackFrame | null, - | DataTrackDepacketizerDropError - | DataTrackDepacketizerDropError - | DataTrackDepacketizerDropError - | DataTrackDepacketizerDropError - > { + ): Throws { if (packet.header.marker !== FrameMarker.Inter && packet.header.marker !== FrameMarker.Final) { // @throws-transformer ignore - this should be treated as a "panic" and not be caught throw new Error( diff --git a/src/room/data-track/e2ee.ts b/src/room/data-track/e2ee.ts index 29064787ae..52ed95f4b0 100644 --- a/src/room/data-track/e2ee.ts +++ b/src/room/data-track/e2ee.ts @@ -10,5 +10,6 @@ export type EncryptionProvider = { }; export type DecryptionProvider = { - decrypt(payload: Uint8Array, senderIdentity: string): Uint8Array; + // FIXME: add in explicit `Throws<..., DecryptionError>`? + decrypt(payload: EncryptedPayload, senderIdentity: string): Uint8Array; }; diff --git a/src/room/data-track/handle.ts b/src/room/data-track/handle.ts index 58f7688026..b005e79129 100644 --- a/src/room/data-track/handle.ts +++ b/src/room/data-track/handle.ts @@ -43,13 +43,7 @@ export class DataTrackHandleError< export type DataTrackHandle = number; export const DataTrackHandle = { - fromNumber( - raw: number, - ): Throws< - DataTrackHandle, - | DataTrackHandleError - | DataTrackHandleError - > { + fromNumber(raw: number): Throws { if (raw === 0) { throw DataTrackHandleError.reserved(raw); } diff --git a/src/room/data-track/incoming/IncomingDataTrackManager.test.ts b/src/room/data-track/incoming/IncomingDataTrackManager.test.ts new file mode 100644 index 0000000000..430225b4b0 --- /dev/null +++ b/src/room/data-track/incoming/IncomingDataTrackManager.test.ts @@ -0,0 +1,568 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { subscribeToEvents } from '../../../utils/subscribeToEvents'; +import { DecryptionProvider, EncryptedPayload } from '../e2ee'; +import { DataTrackFrame } from '../frame'; +import { DataTrackHandle, DataTrackHandleAllocator } from '../handle'; +import { DataTrackPacket, DataTrackPacketHeader, FrameMarker } from '../packet'; +import { DataTrackE2eeExtension, DataTrackExtensions } from '../packet/extensions'; +import { DataTrackTimestamp, WrapAroundUnsignedInt } from '../utils'; +import IncomingDataTrackManager, { + DataTrackIncomingManagerCallbacks, +} from './IncomingDataTrackManager'; +import { DataTrackSubscribeError } from './errors'; + +/** A fake "decryption" provider used for test purposes. Assumes the payload is prefixed with + * 0xdeafbeef, which is stripped off. */ +const PrefixingDecryptionProvider: DecryptionProvider = { + decrypt(p: EncryptedPayload, _senderIdentity: string) { + if ( + p.payload[0] !== 0xde || + p.payload[1] !== 0xad || + p.payload[2] !== 0xbe || + p.payload[3] !== 0xef + ) { + throw new Error( + `PrefixingDecryptionProvider: first four bytes of payload were not 0xdeadbeef, found ${p.payload.slice(0, 4)}`, + ); + } + return p.payload.slice(4); + }, +}; + +describe('DataTrackIncomingManager', () => { + describe('Track publication', () => { + it('should test track publication additions / removals', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + 'trackUnavailable', + ]); + + // 1. Add a track, make sure the track available event was sent + await manager.receiveSfuPublicationUpdates( + new Map([ + [ + 'identity1', + [ + { + sid: 'sid1', + pubHandle: DataTrackHandle.fromNumber(5), + name: 'test', + usesE2ee: false, + }, + ], + ], + ]), + ); + + const trackAvailableEvent = await managerEvents.waitFor('trackAvailable'); + expect(trackAvailableEvent.track.info.sid).toStrictEqual('sid1'); + expect(trackAvailableEvent.track.info.pubHandle).toStrictEqual(DataTrackHandle.fromNumber(5)); + expect(trackAvailableEvent.track.info.name).toStrictEqual('test'); + expect(trackAvailableEvent.track.info.usesE2ee).toStrictEqual(false); + + // 2. Check to make sure the publication has been noted in internal state + expect((await manager.queryPublications()).map((p) => p.pubHandle)).to.deep.equal([ + DataTrackHandle.fromNumber(5), + ]); + + // 3. Remove all tracks, and make sure the internal state is cleared + await manager.receiveSfuPublicationUpdates(new Map([['identity1', []]])); + expect(await manager.queryPublications()).to.deep.equal([]); + + const trackUnavailableEvent = await managerEvents.waitFor('trackUnavailable'); + expect(trackUnavailableEvent.sid).toStrictEqual('sid1'); + expect(trackUnavailableEvent.publisherIdentity).toStrictEqual('identity1'); + }); + + it('should process sfu publication updates idempotently', async () => { + const manager = new IncomingDataTrackManager(); + + // 1. Simulate three identical track publications being received + for (let i = 0; i < 3; i += 1) { + await manager.receiveSfuPublicationUpdates( + new Map([ + [ + 'identity1', + [ + { + sid: 'sid1', + pubHandle: DataTrackHandle.fromNumber(5), + name: 'test', + usesE2ee: false, + }, + ], + ], + ]), + ); + } + + // 2. Check to make sure the publication has been noted in internal state only once + expect((await manager.queryPublications()).map((p) => p.pubHandle)).to.deep.equal([ + DataTrackHandle.fromNumber(5), + ]); + }); + }); + + describe('Track subscription', () => { + it('should test data track subscribing (ok case)', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const senderIdentity = 'identity'; + const sid = 'data track sid'; + const handle = DataTrackHandle.fromNumber(5); + + // 1. Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([[senderIdentity, [{ sid, pubHandle: handle, name: 'test', usesE2ee: false }]]]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Subscribe to a data track + const subscribeRequestPromise = manager.subscribeRequest(sid); + + // 3. This subscribe request should be sent along to the SFU + const sfuUpdateSubscriptionEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(sfuUpdateSubscriptionEvent.sid).toStrictEqual(sid); + expect(sfuUpdateSubscriptionEvent.subscribe).toStrictEqual(true); + + // 4. Once the SFU has acknowledged the subscription, a handle is sent back representing + // the subscription + manager.receivedSfuSubscriberHandles(new Map([[handle, sid]])); + + // 5. Make sure that the subscription promise resolves. + const readableStream = await subscribeRequestPromise; + const reader = readableStream.getReader(); + + // 6. Simulate receiving a packet + manager.packetReceived( + new DataTrackPacket( + new DataTrackPacketHeader({ + extensions: new DataTrackExtensions(), + frameNumber: WrapAroundUnsignedInt.u16(0), + marker: FrameMarker.Single, + sequence: WrapAroundUnsignedInt.u16(0), + timestamp: DataTrackTimestamp.fromRtpTicks(0), + trackHandle: DataTrackHandle.fromNumber(5), + }), + new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05]), + ).toBinary(), + ); + + // 7. Make sure that packet comes out of the ReadableStream + const { value, done } = await reader.read(); + expect(done).toStrictEqual(false); + expect(value?.payload).toStrictEqual(new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05])); + }); + + it('should test data track subscribing with end to end encryption (ok case)', async () => { + const manager = new IncomingDataTrackManager({ + decryptionProvider: PrefixingDecryptionProvider, + }); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const senderIdentity = 'identity'; + const sid = 'data track sid'; + const handle = DataTrackHandle.fromNumber(5); + + // 1. Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([[senderIdentity, [{ sid, pubHandle: handle, name: 'test', usesE2ee: true }]]]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Subscribe to a data track + const subscribeRequestPromise = manager.subscribeRequest(sid); + + // 3. This subscribe request should be sent along to the SFU + const sfuUpdateSubscriptionEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(sfuUpdateSubscriptionEvent.sid).toStrictEqual(sid); + expect(sfuUpdateSubscriptionEvent.subscribe).toStrictEqual(true); + + // 4. Once the SFU has acknowledged the subscription, a handle is sent back representing + // the subscription + manager.receivedSfuSubscriberHandles(new Map([[handle, sid]])); + + // 5. Make sure that the subscription promise resolves. + const readableStream = await subscribeRequestPromise; + const reader = readableStream.getReader(); + + // 6. Simulate receiving a (fake) encrypted packet + manager.packetReceived( + new DataTrackPacket( + new DataTrackPacketHeader({ + extensions: new DataTrackExtensions({ + e2ee: new DataTrackE2eeExtension(0, new Uint8Array(12)), + }), + frameNumber: WrapAroundUnsignedInt.u16(0), + marker: FrameMarker.Single, + sequence: WrapAroundUnsignedInt.u16(0), + timestamp: DataTrackTimestamp.fromRtpTicks(0), + trackHandle: DataTrackHandle.fromNumber(5), + }), + new Uint8Array([ + // Fake encryption bytes prefix + 0xde, 0xad, 0xbe, 0xef, + // Actual payload + 0x01, 0x02, 0x03, 0x04, 0x05, + ]), + ).toBinary(), + ); + + // 7. Make sure that packet comes out of the ReadableStream + const { value, done } = await reader.read(); + expect(done).toStrictEqual(false); + expect(value?.payload).toStrictEqual(new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05])); + }); + + it('should fan out received events across multiple subscriptions', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const senderIdentity = 'identity'; + const sid = 'data track sid'; + + const handleAllocator = new DataTrackHandleAllocator(); + + // 1. Make sure the data track publication is registered + const pubHandle = handleAllocator.get()!; + await manager.receiveSfuPublicationUpdates( + new Map([[senderIdentity, [{ sid, pubHandle, name: 'test', usesE2ee: false }]]]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Set up lots of subscribers + const readers: Array> = []; + for (let index = 0; index < 8; index += 1) { + // Subscribe to a data track + const subscribeRequestPromise = manager.subscribeRequest(sid); + + // Make sure that the sfu interactions ONLY happen for the first subscription opened. + if (index === 0) { + // This subscribe request should be sent along to the SFU + const sfuUpdateSubscriptionEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(sfuUpdateSubscriptionEvent.sid).toStrictEqual(sid); + expect(sfuUpdateSubscriptionEvent.subscribe).toStrictEqual(true); + + // Simulate the subscribe being acknowledged by the SFU + manager.receivedSfuSubscriberHandles( + new Map([[DataTrackHandle.fromNumber(1 /* publish handle */ + index), sid]]), + ); + } + + // 5. Make sure that the subscription promise resolves. + const readableStream = await subscribeRequestPromise; + const reader = readableStream.getReader(); + readers.push(reader); + } + + // 6. Simulate receiving a packet + manager.packetReceived( + new DataTrackPacket( + new DataTrackPacketHeader({ + extensions: new DataTrackExtensions(), + frameNumber: WrapAroundUnsignedInt.u16(0), + marker: FrameMarker.Single, + sequence: WrapAroundUnsignedInt.u16(0), + timestamp: DataTrackTimestamp.fromRtpTicks(0), + trackHandle: pubHandle, + }), + new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05]), + ).toBinary(), + ); + + // 7. Make sure that packet comes out of all of the `ReadableStream`s + const results = await Promise.all(readers.map((reader) => reader.read())); + for (const { value, done } of results) { + expect(done).toStrictEqual(false); + expect(value?.payload).toStrictEqual(new Uint8Array([0x01, 0x02, 0x03, 0x04, 0x05])); + } + }); + + it('should be unable to subscribe to a non existing data track', async () => { + const manager = new IncomingDataTrackManager(); + await expect(manager.subscribeRequest('does not exist')).rejects.toThrowError( + 'Cannot subscribe to data track when disconnected', + ); + }); + + it('should terminate the sfu subscription if the abortsignal is triggered on the only subscription', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const sid = 'data track sid'; + + // 1. Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([ + [ + 'identity', + [{ sid, pubHandle: DataTrackHandle.fromNumber(5), name: 'test', usesE2ee: false }], + ], + ]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Subscribe to a data track + const controller = new AbortController(); + const subscribeRequestPromise = manager.subscribeRequest(sid, controller.signal); + await managerEvents.waitFor('sfuUpdateSubscription'); + + // 3. Cancel the subscription + controller.abort(); + await expect(subscribeRequestPromise).rejects.toThrowError( + 'Subscription to data track cancelled by caller', + ); + + // 4. Make sure the underlying sfu subscription is also terminated, since nothing needs it + // anymore. + const sfuUpdateSubscriptionEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(sfuUpdateSubscriptionEvent.sid).toStrictEqual(sid); + expect(sfuUpdateSubscriptionEvent.subscribe).toStrictEqual(false); + }); + + it('should NOT terminate the sfu subscription if the abortsignal is triggered on one of two active subscriptions', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const sid = 'data track sid'; + + // 1. Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([ + [ + 'identity', + [{ sid, pubHandle: DataTrackHandle.fromNumber(5), name: 'test', usesE2ee: false }], + ], + ]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Subscribe to a data track twice + const controllerOne = new AbortController(); + const subscribeRequestOnePromise = manager.subscribeRequest(sid, controllerOne.signal); + await managerEvents.waitFor('sfuUpdateSubscription'); // Subscription started + + const controllerTwo = new AbortController(); + manager.subscribeRequest(sid, controllerTwo.signal); + + // 3. Cancel the first subscription + controllerOne.abort(); + await expect(subscribeRequestOnePromise).rejects.toThrowError( + 'Subscription to data track cancelled by caller', + ); + + // 4. Make sure the underlying sfu subscription has not been also cancelled, there still is + // one data track subscription active + expect(managerEvents.areThereBufferedEvents('sfuUpdateSubscription')).toBe(false); + }); + + it('should terminate the sfu subscription if the abortsignal is already aborted', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + ]); + + const sid = 'data track sid'; + + // Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([ + [ + 'identity', + [{ sid, pubHandle: DataTrackHandle.fromNumber(5), name: 'test', usesE2ee: false }], + ], + ]), + ); + + // Subscribe to a data track + const subscribeRequestPromise = manager.subscribeRequest( + sid, + AbortSignal.abort(/* already aborted */), + ); + const start = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(start.subscribe).toBe(true); + + // Make sure cancellation is immediately bubbled up + expect(subscribeRequestPromise).rejects.toStrictEqual(DataTrackSubscribeError.cancelled()); + + // Make sure that there is immediately another "unsubscribe" sent + const end = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(end.subscribe).toBe(false); + }); + + it('should terminate the sfu subscription once all listeners have unsubscribed', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const sid = 'data track sid'; + + // 1. Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([ + [ + 'identity', + [{ sid, pubHandle: DataTrackHandle.fromNumber(5), name: 'test', usesE2ee: false }], + ], + ]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Create subscription A + const controllerA = new AbortController(); + const subscribeAPromise = manager.subscribeRequest(sid, controllerA.signal); + const startEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(startEvent.sid).toStrictEqual(sid); + expect(startEvent.subscribe).toStrictEqual(true); + + // 2. Create subscription B + const controllerB = new AbortController(); + const subscribeBPromise = manager.subscribeRequest(sid, controllerB.signal); + expect(managerEvents.areThereBufferedEvents('sfuUpdateSubscription')).toStrictEqual(false); + + // 3. Cancel the subscription A + controllerA.abort(); + expect(managerEvents.areThereBufferedEvents('sfuUpdateSubscription')).toStrictEqual(false); + + // 4. Cancel the subscription B, make sure the underlying sfu subscription is disposed + controllerB.abort(); + const endEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(endEvent.sid).toStrictEqual(sid); + expect(endEvent.subscribe).toStrictEqual(false); + + await expect(subscribeAPromise).rejects.toThrow(); + await expect(subscribeBPromise).rejects.toThrow(); + }); + + it('should terminate PENDING sfu subscriptions if the participant disconnects', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const senderIdentity = 'identity'; + const sid = 'data track sid'; + const handle = DataTrackHandle.fromNumber(5); + + // 1. Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([[senderIdentity, [{ sid, pubHandle: handle, name: 'test', usesE2ee: false }]]]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Begin subscribing to a data track + const promise = manager.subscribeRequest(sid); + + // 3. Simulate the remote participant disconnecting + manager.handleRemoteParticipantDisconnected(senderIdentity); + + // 4. Make sure the pending subscribe was terminated + await expect(promise).rejects.toThrowError( + 'Cannot subscribe to data track when disconnected', + ); + }); + + it('should terminate ACTIVE sfu subscriptions if the participant disconnects', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const senderIdentity = 'identity'; + const sid = 'data track sid'; + const handle = DataTrackHandle.fromNumber(5); + + // 1. Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([[senderIdentity, [{ sid, pubHandle: handle, name: 'test', usesE2ee: false }]]]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Subscribe to a data track, and send the handle back as if the SFU acknowledged it + const subscribeRequestPromise = manager.subscribeRequest(sid); + const sfuUpdateSubscriptionEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(sfuUpdateSubscriptionEvent.sid).toStrictEqual(sid); + expect(sfuUpdateSubscriptionEvent.subscribe).toStrictEqual(true); + manager.receivedSfuSubscriberHandles(new Map([[handle, sid]])); + + // 3. Start an active stream read for later + const reader = (await subscribeRequestPromise).getReader(); + + // 4. Simulate the remote participant disconnecting + manager.handleRemoteParticipantDisconnected(senderIdentity); + + // 5. Make sure the sfu unsubscribes + const endEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(endEvent.sid).toStrictEqual(sid); + expect(endEvent.subscribe).toStrictEqual(false); + + // 6. Make sure the in flight stream read was closed + await reader.closed; + }); + + it('should terminate the sfu subscription once all downstream ReadableStreams are cancelled', async () => { + const manager = new IncomingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuUpdateSubscription', + 'trackAvailable', + ]); + + const senderIdentity = 'identity'; + const sid = 'data track sid'; + const handle = DataTrackHandle.fromNumber(5); + + // 1. Make sure the data track publication is registered + await manager.receiveSfuPublicationUpdates( + new Map([[senderIdentity, [{ sid, pubHandle: handle, name: 'test', usesE2ee: false }]]]), + ); + await managerEvents.waitFor('trackAvailable'); + + // 2. Subscribe to a data track + const subscribeRequestPromise = manager.subscribeRequest(sid); + + // 3. This subscribe request should be sent along to the SFU + const sfuUpdateSubscriptionInitEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(sfuUpdateSubscriptionInitEvent.sid).toStrictEqual(sid); + expect(sfuUpdateSubscriptionInitEvent.subscribe).toStrictEqual(true); + + // 4. Once the SFU has acknowledged the subscription, a handle is sent back representing + // the subscription + manager.receivedSfuSubscriberHandles(new Map([[handle, sid]])); + + // 5. Make sure that the subscription promise resolves. + const readableStream = await subscribeRequestPromise; + const reader = readableStream.getReader(); + + // 6. Manually cancel the readable stream + await reader.cancel(); + + // 7. Make sure the underlying SFU subscription is terminated + const sfuUpdateSubscriptionCancelEvent = await managerEvents.waitFor('sfuUpdateSubscription'); + expect(sfuUpdateSubscriptionCancelEvent.sid).toStrictEqual(sid); + expect(sfuUpdateSubscriptionCancelEvent.subscribe).toStrictEqual(false); + }); + }); +}); diff --git a/src/room/data-track/incoming/IncomingDataTrackManager.ts b/src/room/data-track/incoming/IncomingDataTrackManager.ts new file mode 100644 index 0000000000..9fc22cc9b9 --- /dev/null +++ b/src/room/data-track/incoming/IncomingDataTrackManager.ts @@ -0,0 +1,554 @@ +import { type JoinResponse, type ParticipantUpdate } from '@livekit/protocol'; +import { EventEmitter } from 'events'; +import type TypedEmitter from 'typed-emitter'; +import { LoggerNames, getLogger } from '../../../logger'; +import { abortSignalAny, abortSignalTimeout } from '../../../utils/abort-signal-polyfill'; +import type { Throws } from '../../../utils/throws'; +import type Participant from '../../participant/Participant'; +import type RemoteParticipant from '../../participant/RemoteParticipant'; +import { Future } from '../../utils'; +import RemoteDataTrack from '../RemoteDataTrack'; +import { DataTrackDepacketizerDropError } from '../depacketizer'; +import type { DecryptionProvider } from '../e2ee'; +import type { DataTrackFrame } from '../frame'; +import { DataTrackHandle } from '../handle'; +import { DataTrackPacket } from '../packet'; +import { type DataTrackInfo, type DataTrackSid } from '../types'; +import { DataTrackSubscribeError } from './errors'; +import IncomingDataTrackPipeline from './pipeline'; + +const log = getLogger(LoggerNames.DataTracks); + +type SfuUpdateSubscription = { + /** Identifier of the affected track. */ + sid: DataTrackSid; + /** Whether to subscribe or unsubscribe. */ + subscribe: boolean; +}; + +export type DataTrackIncomingManagerCallbacks = { + /** Request sent to the SFU to update the subscription for a data track. */ + sfuUpdateSubscription: (event: SfuUpdateSubscription) => void; + + /** A track has been published by a remote participant and is available to be + * subscribed to. */ + trackAvailable: (event: { track: RemoteDataTrack }) => void; + + /** A track has been unpublished by a remote participant and can no longer be subscribed to. */ + trackUnavailable: (event: { + sid: DataTrackSid; + publisherIdentity: Participant['identity']; + }) => void; +}; + +/** Track is not subscribed to. */ +type SubscriptionStateNone = { type: 'none' }; +/** Track is being subscribed to, waiting for subscriber handle. */ +type SubscriptionStatePending = { + type: 'pending'; + completionFuture: Future; + /** The number of in flight requests waiting for this subscription state to go to "active". */ + pendingRequestCount: number; + /** A function that when called, cancels the pending subscription and moves back to "none". */ + cancel: () => void; +}; +/** Track has an active subscription. */ +type SubscriptionStateActive = { + type: 'active'; + subcriptionHandle: DataTrackHandle; + pipeline: IncomingDataTrackPipeline; + streamControllers: Set>; +}; + +type SubscriptionState = SubscriptionStateNone | SubscriptionStatePending | SubscriptionStateActive; + +/** Information and state for a remote data track. */ +type Descriptor = { + info: DataTrackInfo; + publisherIdentity: Participant['identity']; + subscription: S; +}; + +type IncomingDataTrackManagerOptions = { + /** Provider to use for decrypting incoming frame payloads. + * If none, remote tracks using end-to-end encryption will not be available + * for subscription. + */ + decryptionProvider: DecryptionProvider | null; +}; + +/** How long to wait when attempting to subscribe before timing out. */ +const SUBSCRIBE_TIMEOUT_MILLISECONDS = 10_000; + +/** Maximum number of {@link DataTrackFrame}s that are cached for each ReadableStream subscription. + * If data comes in too fast and saturates this threshold, backpressure will be applied. */ +const READABLE_STREAM_DEFAULT_HIGH_WATER_MARK = 16; + +export default class IncomingDataTrackManager extends (EventEmitter as new () => TypedEmitter) { + private decryptionProvider: DecryptionProvider | null; + + /** Mapping between track SID and descriptor. */ + private descriptors = new Map>(); + + /** Mapping between subscriber handle and track SID. + * + * This is an index that allows track descriptors to be looked up + * by subscriber handle in O(1) time, to make routing incoming packets + * a (hot code path) faster. + */ + private subscriptionHandles = new Map(); + + constructor(options?: IncomingDataTrackManagerOptions) { + super(); + this.decryptionProvider = options?.decryptionProvider ?? null; + } + + /** Client requested to subscribe to a data track. + * + * This is sent when the user calls {@link RemoteDataTrack.subscribe}. + * + * Only the first request to subscribe to a given track incurs meaningful overhead; subsequent + * requests simply attach an additional receiver to the broadcast channel, allowing them to consume + * frames from the existing subscription pipeline. + */ + async subscribeRequest( + sid: DataTrackSid, + signal?: AbortSignal, + highWaterMark = READABLE_STREAM_DEFAULT_HIGH_WATER_MARK, + ): Promise< + Throws< + ReadableStream, + DataTrackSubscribeError + > + > { + const descriptor = this.descriptors.get(sid); + if (!descriptor) { + // FIXME: maybe this should be a DataTrackSubscribeError.disconnected()? That's what happens + // here (on the caller end in the rust implementation): + // https://github.com/livekit/rust-sdks/blob/ccdc012e40f9b2cf6b677c07da7061216eb93a89/livekit-datatrack/src/remote/mod.rs#L81 + throw DataTrackSubscribeError.disconnected(); + + // FIXME: DataTrackSubscribeError.unpublished is unused both here and in rust + + // @throws-transformer ignore - this should be treated as a "panic" and not be caught + // throw new Error('Cannot subscribe to unknown track'); + } + + const waitForCompletionFuture = async ( + currentDescriptor: Descriptor, + userProvidedSignal?: AbortSignal, + timeoutSignal?: AbortSignal, + ) => { + if (currentDescriptor.subscription.type !== 'pending') { + // @throws-transformer ignore - this should be treated as a "panic" and not be caught + throw new Error( + `Descriptor for track ${sid} is not pending, found ${currentDescriptor.subscription.type}`, + ); + } + + const combinedSignal = abortSignalAny( + [userProvidedSignal, timeoutSignal].filter( + (s): s is AbortSignal => typeof s !== 'undefined', + ), + ); + + const proxiedCompletionFuture = new Future(); + currentDescriptor.subscription.completionFuture.promise + .then(() => proxiedCompletionFuture.resolve?.()) + .catch((err) => proxiedCompletionFuture.reject?.(err)); + + const onAbort = () => { + if (currentDescriptor.subscription.type !== 'pending') { + return; + } + currentDescriptor.subscription.pendingRequestCount -= 1; + + if (timeoutSignal?.aborted) { + // A timeout should apply to the underlying SFU subscription and cancel all user + // subscriptions. + currentDescriptor.subscription.cancel(); + return; + } + + if (currentDescriptor.subscription.pendingRequestCount <= 0) { + // No user subscriptions are still pending, so cancel the underlying pending `sfuUpdateSubscription` + currentDescriptor.subscription.cancel(); + return; + } + + // Other subscriptions are still pending for this data track, so just cancel this one + // active user subscription, and leave the rest of the user subscriptions alone. + proxiedCompletionFuture.reject?.(DataTrackSubscribeError.cancelled()); + }; + + if (combinedSignal.aborted) { + onAbort(); + } + combinedSignal.addEventListener('abort', onAbort); + await proxiedCompletionFuture.promise; + combinedSignal.removeEventListener('abort', onAbort); + + return this.createReadableStream(sid); + }; + + switch (descriptor.subscription.type) { + case 'none': { + descriptor.subscription = { + type: 'pending', + completionFuture: new Future(), + pendingRequestCount: 1, + cancel: () => { + const previousDescriptorSubscription = descriptor.subscription; + descriptor.subscription = { type: 'none' }; + + // Let the SFU know that the subscribe has been cancelled + this.emit('sfuUpdateSubscription', { sid, subscribe: false }); + + if (previousDescriptorSubscription.type === 'pending') { + previousDescriptorSubscription.completionFuture.reject?.( + timeoutSignal.aborted + ? DataTrackSubscribeError.timeout() + : // NOTE: the below cancelled case was introduced by web / there isn't a corresponding case in the rust version. + DataTrackSubscribeError.cancelled(), + ); + } + }, + }; + + this.emit('sfuUpdateSubscription', { sid, subscribe: true }); + + const timeoutSignal = abortSignalTimeout(SUBSCRIBE_TIMEOUT_MILLISECONDS); + + // Wait for the subscription to complete, or time out if it takes too long + const reader = await waitForCompletionFuture(descriptor, signal, timeoutSignal); + return reader; + } + case 'pending': { + descriptor.subscription.pendingRequestCount += 1; + + // Wait for the subscription to complete + const reader = await waitForCompletionFuture(descriptor, signal); + return reader; + } + case 'active': { + return this.createReadableStream(sid); + } + } + } + + /** Allocates a ReadableStream which emits when a new {@link DataTrackFrame} is received from the + * SFU. */ + private createReadableStream( + sid: DataTrackSid, + highWaterMark = READABLE_STREAM_DEFAULT_HIGH_WATER_MARK, + ) { + let streamController: ReadableStreamDefaultController | null = null; + return new ReadableStream( + { + start: (controller) => { + streamController = controller; + const descriptor = this.descriptors.get(sid); + if (!descriptor) { + log.error(`Unknown track ${sid}`); + return; + } + if (descriptor.subscription.type !== 'active') { + log.error(`Subscription for track ${sid} is not active`); + return; + } + + descriptor.subscription.streamControllers.add(controller); + }, + cancel: () => { + if (!streamController) { + log.warn(`ReadableStream subscribed to ${sid} was not started.`); + return; + } + const descriptor = this.descriptors.get(sid); + if (!descriptor) { + log.warn(`Unknown track ${sid}, skipping cancel...`); + return; + } + if (descriptor.subscription.type !== 'active') { + log.warn(`Subscription for track ${sid} is not active, skipping cancel...`); + return; + } + + descriptor.subscription.streamControllers.delete(streamController); + + // If no active stream controllers are left, also unsubscribe on the SFU end. + if (descriptor.subscription.streamControllers.size === 0) { + this.unSubscribeRequest(descriptor.info.sid); + } + }, + }, + new CountQueuingStrategy({ highWaterMark }), + ); + } + + /** + * Get information about all currently subscribed tracks. + * @internal */ + async querySubscribed() { + const descriptorInfos = Array.from(this.descriptors.values()) + .filter( + (descriptor): descriptor is Descriptor => + descriptor.subscription.type === 'active', + ) + .map( + (descriptor) => + [descriptor.info, descriptor.publisherIdentity] as [ + info: DataTrackInfo, + identity: Participant['identity'], + ], + ); + + return descriptorInfos; + } + + /** Client requested to unsubscribe from a data track. */ + unSubscribeRequest(sid: DataTrackSid) { + const descriptor = this.descriptors.get(sid); + if (!descriptor) { + // FIXME: rust implementation returns here, not throws + // @throws-transformer ignore - this should be treated as a "panic" and not be caught + throw new Error('Cannot subscribe to unknown track'); + } + + if (descriptor.subscription.type !== 'active') { + log.warn( + `Unexpected descriptor state in unSubscribeRequest, expected active, found ${descriptor.subscription?.type}`, + ); + return; + } + + for (const controller of descriptor.subscription.streamControllers) { + controller.close(); + } + + // FIXME: this might be wrong? Shouldn't this only occur if it is the last subscription to + // terminate? + const previousDescriptorSubscription = descriptor.subscription; + descriptor.subscription = { type: 'none' }; + this.subscriptionHandles.delete(previousDescriptorSubscription.subcriptionHandle); + + this.emit('sfuUpdateSubscription', { sid, subscribe: false }); + } + + /** SFU notification that track publications have changed. + * + * This event is produced from both {@link JoinResponse} and {@link ParticipantUpdate} + * to provide a complete view of remote participants' track publications: + * + * - From a `JoinResponse`, it captures the initial set of tracks published when a participant joins. + * - From a `ParticipantUpdate`, it captures subsequent changes (i.e., new tracks being + * published and existing tracks unpublished). + */ + async receiveSfuPublicationUpdates(updates: Map>) { + if (updates.size === 0) { + return; + } + + // Detect published track + const sidsInUpdate = new Set(); + for (const [publisherIdentity, infos] of updates.entries()) { + for (const info of infos) { + sidsInUpdate.add(info.sid); + if (this.descriptors.has(info.sid)) { + continue; + } + await this.handleTrackPublished(publisherIdentity, info); + } + } + + // Detect unpublished tracks + let unpublishedSids = Array.from(this.descriptors.keys()).filter( + (sid) => !sidsInUpdate.has(sid), + ); + for (const sid of unpublishedSids) { + this.handleTrackUnpublished(sid); + } + } + + /** + * Get information about all currently remotely published tracks which could be subscribed to. + * @internal */ + async queryPublications() { + return Array.from(this.descriptors.values()).map((descriptor) => descriptor.info); + } + + async handleTrackPublished(publisherIdentity: Participant['identity'], info: DataTrackInfo) { + if (this.descriptors.has(info.sid)) { + log.error(`Existing descriptor for track ${info.sid}`); + return; + } + let descriptor: Descriptor = { + info, + publisherIdentity, + subscription: { type: 'none' }, + }; + this.descriptors.set(descriptor.info.sid, descriptor); + + const track = new RemoteDataTrack(descriptor.info, this, { publisherIdentity }); + this.emit('trackAvailable', { track }); + } + + handleTrackUnpublished(sid: DataTrackSid) { + const descriptor = this.descriptors.get(sid); + if (!descriptor) { + log.error(`Unknown track ${sid}`); + return; + } + this.descriptors.delete(sid); + + if (descriptor.subscription.type === 'active') { + this.subscriptionHandles.delete(descriptor.subscription.subcriptionHandle); + } + + this.emit('trackUnavailable', { sid, publisherIdentity: descriptor.publisherIdentity }); + } + + /** SFU notification that handles have been assigned for requested subscriptions. */ + receivedSfuSubscriberHandles( + /** Mapping between track handles attached to incoming packets to the + * track SIDs they belong to. */ + mapping: Map, + ) { + for (const [handle, sid] of mapping.entries()) { + this.registerSubscriberHandle(handle, sid); + } + } + + private registerSubscriberHandle(assignedHandle: DataTrackHandle, sid: DataTrackSid) { + const descriptor = this.descriptors.get(sid); + if (!descriptor) { + log.error(`Unknown track ${sid}`); + return; + } + switch (descriptor.subscription.type) { + case 'none': { + // Handle assigned when there is no pending or active subscription is unexpected. + log.warn(`No subscription for ${sid}`); + return; + } + case 'active': { + // Update handle for an active subscription. This can occur following a full reconnect. + descriptor.subscription.subcriptionHandle = assignedHandle; + this.subscriptionHandles.set(assignedHandle, sid); + return; + } + case 'pending': { + const pipeline = new IncomingDataTrackPipeline({ + info: descriptor.info, + publisherIdentity: descriptor.publisherIdentity, + decryptionProvider: this.decryptionProvider, + }); + + const previousDescriptorSubscription = descriptor.subscription; + descriptor.subscription = { + type: 'active', + subcriptionHandle: assignedHandle, + pipeline, + streamControllers: new Set(), + }; + this.subscriptionHandles.set(assignedHandle, sid); + + previousDescriptorSubscription.completionFuture.resolve?.(); + } + } + } + + /** Packet has been received over the transport. */ + packetReceived(bytes: Uint8Array): Throws { + let packet: DataTrackPacket; + try { + [packet] = DataTrackPacket.fromBinary(bytes); + } catch (err) { + log.error(`Failed to deserialize packet: ${err}`); + return; + } + + const sid = this.subscriptionHandles.get(packet.header.trackHandle); + if (!sid) { + log.warn(`Unknown subscriber handle ${packet.header.trackHandle}`); + return; + } + + const descriptor = this.descriptors.get(sid); + if (!descriptor) { + log.error(`Missing descriptor for track ${sid}`); + return; + } + + if (descriptor.subscription.type !== 'active') { + log.warn(`Received packet for track ${sid} without active subscription`); + return; + } + + const frame = descriptor.subscription.pipeline.processPacket(packet); + if (!frame) { + // Not all packets have been received yet to form a complete frame + return; + } + + // Broadcast to all downstream subscribers + for (const controller of descriptor.subscription.streamControllers) { + if (controller.desiredSize !== null && controller.desiredSize <= 0) { + log.warn( + `Cannot send frame to subscribers: readable stream is full (desiredSize is ${controller.desiredSize}). To increase this threshold, set a higher 'options.highWaterMark' when calling .subscribe().`, + ); + continue; + } + controller.enqueue(frame); + } + } + + /** Resend all subscription updates. + * + * This must be sent after a full reconnect to ensure the SFU knows which + * tracks are subscribed to locally. + */ + resendSubscriptionUpdates() { + for (const [sid, descriptor] of this.descriptors) { + if (descriptor.subscription.type === 'none') { + continue; + } + this.emit('sfuUpdateSubscription', { sid, subscribe: true }); + } + } + + /** Called when a remote participant is disconnected so that any pending data tracks can be + * cancelled. */ + handleRemoteParticipantDisconnected(remoteParticipantIdentity: RemoteParticipant['identity']) { + for (const descriptor of this.descriptors.values()) { + if (descriptor.publisherIdentity !== remoteParticipantIdentity) { + continue; + } + switch (descriptor.subscription.type) { + case 'none': + break; + case 'pending': + descriptor.subscription.completionFuture.reject?.(DataTrackSubscribeError.disconnected()); + break; + case 'active': + this.unSubscribeRequest(descriptor.info.sid); + break; + } + } + } + + /** Shutdown the manager, ending any subscriptions. */ + shutdown() { + for (const descriptor of this.descriptors.values()) { + this.emit('trackUnavailable', { + sid: descriptor.info.sid, + publisherIdentity: descriptor.publisherIdentity, + }); + + if (descriptor.subscription.type === 'pending') { + descriptor.subscription.completionFuture.reject?.(DataTrackSubscribeError.disconnected()); + } + } + this.descriptors.clear(); + } +} diff --git a/src/room/data-track/incoming/errors.ts b/src/room/data-track/incoming/errors.ts new file mode 100644 index 0000000000..3d414b7fce --- /dev/null +++ b/src/room/data-track/incoming/errors.ts @@ -0,0 +1,57 @@ +import { LivekitReasonedError } from '../../errors'; + +export enum DataTrackSubscribeErrorReason { + /** The track has been unpublished and is no longer available */ + Unpublished = 0, + /** Request to subscribe to data track timed-out */ + Timeout = 1, + /** Cannot subscribe to data track when disconnected */ + Disconnected = 2, + /** Subscription to data track cancelled by caller */ + Cancelled = 4, +} + +export class DataTrackSubscribeError< + Reason extends DataTrackSubscribeErrorReason = DataTrackSubscribeErrorReason, +> extends LivekitReasonedError { + readonly name = 'DataTrackSubscribeError'; + + reason: Reason; + + reasonName: string; + + constructor(message: string, reason: Reason, options?: { cause?: unknown }) { + super(22, message, options); + this.reason = reason; + this.reasonName = DataTrackSubscribeErrorReason[reason]; + } + + static unpublished() { + return new DataTrackSubscribeError( + 'The track has been unpublished and is no longer available', + DataTrackSubscribeErrorReason.Unpublished, + ); + } + + static timeout() { + return new DataTrackSubscribeError( + 'Request to subscribe to data track timed-out', + DataTrackSubscribeErrorReason.Timeout, + ); + } + + static disconnected() { + return new DataTrackSubscribeError( + 'Cannot subscribe to data track when disconnected', + DataTrackSubscribeErrorReason.Disconnected, + ); + } + + // NOTE: this was introduced by web / there isn't a corresponding case in the rust version. + static cancelled() { + return new DataTrackSubscribeError( + 'Subscription to data track cancelled by caller', + DataTrackSubscribeErrorReason.Cancelled, + ); + } +} diff --git a/src/room/data-track/incoming/pipeline.ts b/src/room/data-track/incoming/pipeline.ts new file mode 100644 index 0000000000..5009f376ec --- /dev/null +++ b/src/room/data-track/incoming/pipeline.ts @@ -0,0 +1,116 @@ +import { LoggerNames, getLogger } from '../../../logger'; +import type { Throws } from '../../../utils/throws'; +import DataTrackDepacketizer, { DataTrackDepacketizerDropError } from '../depacketizer'; +import type { DecryptionProvider, EncryptedPayload } from '../e2ee'; +import type { DataTrackFrame } from '../frame'; +import { DataTrackPacket } from '../packet'; +import { type DataTrackInfo } from '../types'; + +const log = getLogger(LoggerNames.DataTracks); + +/** + * Options for creating a {@link IncomingDataTrackPipeline}. + */ +type Options = { + info: DataTrackInfo; + publisherIdentity: string; + decryptionProvider: DecryptionProvider | null; +}; + +/** + * Pipeline for an individual data track subscription. + */ +export default class IncomingDataTrackPipeline { + private publisherIdentity: string; + + private e2eeProvider: DecryptionProvider | null; + + private depacketizer: DataTrackDepacketizer; + + /** + * Creates a new pipeline with the given options. + */ + constructor(options: Options) { + const hasProvider = options.decryptionProvider !== null; + if (options.info.usesE2ee !== hasProvider) { + // @throws-transformer ignore - this should be treated as a "panic" and not be caught + throw new Error( + 'IncomingDataTrackPipeline: DataTrackInfo.usesE2ee must match presence of decryptionProvider', + ); + } + + const depacketizer = new DataTrackDepacketizer(); + + this.publisherIdentity = options.publisherIdentity; + this.e2eeProvider = options.decryptionProvider ?? null; + this.depacketizer = depacketizer; + } + + processPacket( + packet: DataTrackPacket, + ): Throws { + const frame = this.depacketize(packet); + if (!frame) { + return null; + } + + const decrypted = this.decryptIfNeeded(frame); + if (!decrypted) { + return null; + } + + return decrypted; + } + + /** + * Depacketize the given frame, log if a drop occurs. + */ + private depacketize( + packet: DataTrackPacket, + ): Throws { + let frame: DataTrackFrame | null; + try { + frame = this.depacketizer.push(packet); + } catch (err) { + // In a future version, use this to maintain drop statistics. + // FIXME: is this a good idea? + log.debug(`Data frame depacketize error: ${err}`); + return null; + } + return frame; + } + + /** + * Decrypt the frame's payload if E2EE is enabled for this track. + */ + private decryptIfNeeded(frame: DataTrackFrame): DataTrackFrame | null { + const decryption = this.e2eeProvider; + + if (!decryption) { + return frame; + } + + const e2ee = frame.extensions?.e2ee ?? null; + if (!e2ee) { + log.error('Missing E2EE meta'); + return null; + } + + const encrypted: EncryptedPayload = { + payload: frame.payload, + iv: e2ee.iv, + keyIndex: e2ee.keyIndex, + }; + + let result: Uint8Array; + try { + result = decryption.decrypt(encrypted, this.publisherIdentity); + } catch (err) { + log.error(`Error decrypting packet: ${err}`); + return null; + } + + frame.payload = result; + return frame; + } +} diff --git a/src/room/data-track/outgoing/OutgoingDataTrackManager.test.ts b/src/room/data-track/outgoing/OutgoingDataTrackManager.test.ts index cea8f6f8c7..3727d01fdc 100644 --- a/src/room/data-track/outgoing/OutgoingDataTrackManager.test.ts +++ b/src/room/data-track/outgoing/OutgoingDataTrackManager.test.ts @@ -110,6 +110,26 @@ describe('DataTrackOutgoingManager', () => { expect(publishRequestPromise).rejects.toStrictEqual(DataTrackPublishError.cancelled()); }); + it('should test track publishing (cancellation before it starts)', async () => { + const manager = new OutgoingDataTrackManager(); + const managerEvents = subscribeToEvents(manager, [ + 'sfuPublishRequest', + 'sfuUnpublishRequest', + ]); + + // Publish a data track + const publishRequestPromise = manager.publishRequest( + { name: 'test' }, + AbortSignal.abort(/* already aborted */), + ); + + // Make sure cancellation is immediately bubbled up + expect(publishRequestPromise).rejects.toStrictEqual(DataTrackPublishError.cancelled()); + + // And there were no pending sfu publish requests sent + expect(managerEvents.areThereBufferedEvents('sfuPublishRequest')).toBe(false); + }); + it.each([ // Single packet payload case [ diff --git a/src/room/data-track/outgoing/OutgoingDataTrackManager.ts b/src/room/data-track/outgoing/OutgoingDataTrackManager.ts index 244616bbb5..a841efb65a 100644 --- a/src/room/data-track/outgoing/OutgoingDataTrackManager.ts +++ b/src/room/data-track/outgoing/OutgoingDataTrackManager.ts @@ -1,13 +1,15 @@ import { EventEmitter } from 'events'; import type TypedEmitter from 'typed-emitter'; import { LoggerNames, getLogger } from '../../../logger'; +import { abortSignalAny, abortSignalTimeout } from '../../../utils/abort-signal-polyfill'; import type { Throws } from '../../../utils/throws'; import { Future } from '../../utils'; +import LocalDataTrack from '../LocalDataTrack'; import { type EncryptionProvider } from '../e2ee'; import type { DataTrackFrame } from '../frame'; import { DataTrackHandle, DataTrackHandleAllocator } from '../handle'; import { DataTrackExtensions } from '../packet/extensions'; -import { type DataTrackInfo, LocalDataTrack } from '../track'; +import { type DataTrackInfo } from '../types'; import { DataTrackPublishError, DataTrackPublishErrorReason, @@ -74,7 +76,7 @@ export type DataTrackOutgoingManagerCallbacks = { packetsAvailable: (event: OutputEventPacketsAvailable) => void; }; -type DataTrackLocalManagerOptions = { +type OutgoingDataTrackManagerOptions = { /** * Provider to use for encrypting outgoing frame payloads. * @@ -93,7 +95,7 @@ export default class OutgoingDataTrackManager extends (EventEmitter as new () => private descriptors = new Map(); - constructor(options?: DataTrackLocalManagerOptions) { + constructor(options?: OutgoingDataTrackManagerOptions) { super(); this.encryptionProvider = options?.encryptionProvider ?? null; } @@ -160,8 +162,8 @@ export default class OutgoingDataTrackManager extends (EventEmitter as new () => throw DataTrackPublishError.limitReached(); } - const timeoutSignal = AbortSignal.timeout(PUBLISH_TIMEOUT_MILLISECONDS); - const combinedSignal = signal ? AbortSignal.any([signal, timeoutSignal]) : timeoutSignal; + const timeoutSignal = abortSignalTimeout(PUBLISH_TIMEOUT_MILLISECONDS); + const combinedSignal = signal ? abortSignalAny([signal, timeoutSignal]) : timeoutSignal; if (this.descriptors.has(handle)) { // @throws-transformer ignore - this should be treated as a "panic" and not be caught @@ -191,6 +193,10 @@ export default class OutgoingDataTrackManager extends (EventEmitter as new () => ); } }; + if (combinedSignal.aborted) { + onAbort(); // NOTE: this rejects `completionFuture`; the next line just returns the rejection + return descriptor.completionFuture.promise; + } combinedSignal.addEventListener('abort', onAbort); this.emit('sfuPublishRequest', { diff --git a/src/room/data-track/outgoing/errors.ts b/src/room/data-track/outgoing/errors.ts index 2340c23593..87c23c64e3 100644 --- a/src/room/data-track/outgoing/errors.ts +++ b/src/room/data-track/outgoing/errors.ts @@ -1,5 +1,5 @@ import { LivekitReasonedError } from '../../errors'; -import { DataTrackPacketizerError, DataTrackPacketizerReason } from '../packetizer'; +import { DataTrackPacketizerError } from '../packetizer'; export enum DataTrackPublishErrorReason { /** @@ -91,7 +91,7 @@ export enum DataTrackPushFrameErrorReason { } export class DataTrackPushFrameError< - Reason extends DataTrackPushFrameErrorReason, + Reason extends DataTrackPushFrameErrorReason = DataTrackPushFrameErrorReason, > extends LivekitReasonedError { readonly name = 'DataTrackPushFrameError'; @@ -125,7 +125,7 @@ export enum DataTrackOutgoingPipelineErrorReason { } export class DataTrackOutgoingPipelineError< - Reason extends DataTrackOutgoingPipelineErrorReason, + Reason extends DataTrackOutgoingPipelineErrorReason = DataTrackOutgoingPipelineErrorReason, > extends LivekitReasonedError { readonly name = 'DataTrackOutgoingPipelineError'; @@ -139,7 +139,7 @@ export class DataTrackOutgoingPipelineError< this.reasonName = DataTrackOutgoingPipelineErrorReason[reason]; } - static packetizer(cause: DataTrackPacketizerError) { + static packetizer(cause: DataTrackPacketizerError) { return new DataTrackOutgoingPipelineError( 'Error packetizing frame', DataTrackOutgoingPipelineErrorReason.Packetizer, diff --git a/src/room/data-track/outgoing/pipeline.ts b/src/room/data-track/outgoing/pipeline.ts index f4f0c252af..8dcdd83b62 100644 --- a/src/room/data-track/outgoing/pipeline.ts +++ b/src/room/data-track/outgoing/pipeline.ts @@ -4,7 +4,7 @@ import { type DataTrackFrame } from '../frame'; import { DataTrackPacket } from '../packet'; import { DataTrackE2eeExtension } from '../packet/extensions'; import DataTrackPacketizer, { DataTrackPacketizerError } from '../packetizer'; -import type { DataTrackInfo } from '../track'; +import type { DataTrackInfo } from '../types'; import { DataTrackOutgoingPipelineError, DataTrackOutgoingPipelineErrorReason } from './errors'; type Options = { @@ -31,11 +31,7 @@ export default class DataTrackOutgoingPipeline { *processFrame( frame: DataTrackFrame, - ): Throws< - Generator, - | DataTrackOutgoingPipelineError - | DataTrackOutgoingPipelineError - > { + ): Throws, DataTrackOutgoingPipelineError> { const encryptedFrame = this.encryptIfNeeded(frame); try { diff --git a/src/room/data-track/outgoing/types.ts b/src/room/data-track/outgoing/types.ts index 95492ffbd7..85a1a236fa 100644 --- a/src/room/data-track/outgoing/types.ts +++ b/src/room/data-track/outgoing/types.ts @@ -1,5 +1,5 @@ import { type DataTrackHandle } from '../handle'; -import { type DataTrackInfo } from '../track'; +import { type DataTrackInfo } from '../types'; import { type DataTrackPublishError, type DataTrackPublishErrorReason } from './errors'; /** Options for publishing a data track. */ @@ -33,5 +33,4 @@ export type OutputEventSfuUnpublishRequest = { /** Serialized packets are ready to be sent over the transport. */ export type OutputEventPacketsAvailable = { bytes: Uint8Array; - signal?: AbortSignal; }; diff --git a/src/room/data-track/packet/errors.ts b/src/room/data-track/packet/errors.ts index f14d003783..f7b8b34d51 100644 --- a/src/room/data-track/packet/errors.ts +++ b/src/room/data-track/packet/errors.ts @@ -11,7 +11,7 @@ export enum DataTrackDeserializeErrorReason { } export class DataTrackDeserializeError< - Reason extends DataTrackDeserializeErrorReason, + Reason extends DataTrackDeserializeErrorReason = DataTrackDeserializeErrorReason, > extends LivekitReasonedError { readonly name = 'DataTrackDeserializeError'; @@ -73,21 +73,13 @@ export class DataTrackDeserializeError< } } -export type DataTrackDeserializeErrorAll = - | DataTrackDeserializeError - | DataTrackDeserializeError - | DataTrackDeserializeError - | DataTrackDeserializeError - | DataTrackDeserializeError - | DataTrackDeserializeError; - export enum DataTrackSerializeErrorReason { TooSmallForHeader = 0, TooSmallForPayload = 1, } export class DataTrackSerializeError< - Reason extends DataTrackSerializeErrorReason, + Reason extends DataTrackSerializeErrorReason = DataTrackSerializeErrorReason, > extends LivekitReasonedError { readonly name = 'DataTrackSerializeError'; @@ -115,7 +107,3 @@ export class DataTrackSerializeError< ); } } - -export type DataTrackSerializeErrorAll = - | DataTrackSerializeError - | DataTrackSerializeError; diff --git a/src/room/data-track/packet/extensions.ts b/src/room/data-track/packet/extensions.ts index a96d7847b5..c56aa31069 100644 --- a/src/room/data-track/packet/extensions.ts +++ b/src/room/data-track/packet/extensions.ts @@ -20,7 +20,7 @@ export class DataTrackUserTimestampExtension extends DataTrackExtension { static lengthBytes = 8; - private timestamp: bigint; + timestamp: bigint; constructor(timestamp: bigint) { super(); @@ -74,9 +74,9 @@ export class DataTrackE2eeExtension extends DataTrackExtension { static lengthBytes = 13; - private keyIndex: number; + keyIndex: number; - private iv: Uint8Array; /* NOTE: According to the rust implementation, this should be 12 bytes long. */ + iv: Uint8Array; /* NOTE: According to the rust implementation, this should be 12 bytes long. */ constructor(keyIndex: number, iv: Uint8Array) { super(); diff --git a/src/room/data-track/packet/index.ts b/src/room/data-track/packet/index.ts index 74e4d89a19..0351a4dfbf 100644 --- a/src/room/data-track/packet/index.ts +++ b/src/room/data-track/packet/index.ts @@ -26,9 +26,7 @@ import { } from './constants'; import { DataTrackDeserializeError, - type DataTrackDeserializeErrorAll, DataTrackSerializeError, - type DataTrackSerializeErrorAll, DataTrackSerializeErrorReason, } from './errors'; import { DataTrackExtensions } from './extensions'; @@ -167,7 +165,7 @@ export class DataTrackPacketHeader extends Serializable { static fromBinary( input: Input, - ): Throws<[header: DataTrackPacketHeader, byteLength: number], DataTrackDeserializeErrorAll> { + ): Throws<[header: DataTrackPacketHeader, byteLength: number], DataTrackDeserializeError> { const dataView = coerceToDataView(input); if (dataView.byteLength < BASE_HEADER_LEN) { @@ -314,7 +312,7 @@ export class DataTrackPacket extends Serializable { return this.header.toBinaryLengthBytes() + this.payload.byteLength; } - toBinaryInto(dataView: DataView): Throws { + toBinaryInto(dataView: DataView): Throws { let byteIndex = 0; const headerLengthBytes = this.header.toBinaryInto(dataView); byteIndex += headerLengthBytes; @@ -341,7 +339,7 @@ export class DataTrackPacket extends Serializable { static fromBinary( input: Input, - ): Throws<[packet: DataTrackPacket, byteLength: number], DataTrackDeserializeErrorAll> { + ): Throws<[packet: DataTrackPacket, byteLength: number], DataTrackDeserializeError> { const dataView = coerceToDataView(input); const [header, headerByteLength] = DataTrackPacketHeader.fromBinary(dataView); diff --git a/src/room/data-track/packet/serializable.ts b/src/room/data-track/packet/serializable.ts index 40a5e329e1..b8bf538f34 100644 --- a/src/room/data-track/packet/serializable.ts +++ b/src/room/data-track/packet/serializable.ts @@ -1,5 +1,5 @@ import { type Throws } from '../../../utils/throws'; -import { type DataTrackSerializeErrorAll } from './errors'; +import { DataTrackSerializeError } from './errors'; /** An abstract class implementing common behavior related to data track binary serialization. */ export default abstract class Serializable { @@ -7,10 +7,10 @@ export default abstract class Serializable { abstract toBinaryLengthBytes(): number; /** Given a DataView, serialize the instance inside and return the number of bytes written. */ - abstract toBinaryInto(dataView: DataView): Throws; + abstract toBinaryInto(dataView: DataView): Throws; /** Encodes the instance as binary and returns the data as a Uint8Array. */ - toBinary(): Throws { + toBinary(): Throws { const lengthBytes = this.toBinaryLengthBytes(); const output = new ArrayBuffer(lengthBytes); const view = new DataView(output); diff --git a/src/room/data-track/packetizer.ts b/src/room/data-track/packetizer.ts index 0e6a8fe7cd..e0b93b8903 100644 --- a/src/room/data-track/packetizer.ts +++ b/src/room/data-track/packetizer.ts @@ -12,7 +12,7 @@ type PacketizeOptions = { }; export class DataTrackPacketizerError< - Reason extends DataTrackPacketizerReason, + Reason extends DataTrackPacketizerReason = DataTrackPacketizerReason, > extends LivekitReasonedError { readonly name = 'DataTrackPacketizerError'; @@ -78,10 +78,7 @@ export default class DataTrackPacketizer { *packetize( frame: DataTrackFrame, options?: PacketizeOptions, - ): Throws< - Generator, - DataTrackPacketizerError - > { + ): Throws, DataTrackPacketizerError> { const frameNumber = this.frameNumber.getThenIncrement(); const headerParams = { marker: FrameMarker.Inter, diff --git a/src/room/data-track/track-interfaces.ts b/src/room/data-track/track-interfaces.ts new file mode 100644 index 0000000000..412b6632e0 --- /dev/null +++ b/src/room/data-track/track-interfaces.ts @@ -0,0 +1,53 @@ +import type { DataTrackInfo } from './types'; + +function isObject(subject: unknown): subject is object { + return subject !== null && typeof subject === 'object'; +} + +export const TrackSymbol: symbol = Symbol.for('lk.track'); + +export interface ITrack { + readonly trackSymbol: typeof TrackSymbol; +} + +function isTrack(subject: unknown): subject is ITrack { + return isObject(subject) && 'trackSymbol' in subject && subject.trackSymbol === TrackSymbol; +} + +/** An interface representing a track (of any type) which is local and sending data to the SFU. */ +export interface ILocalTrack extends ITrack { + readonly isLocal: true; + + isPublished(): boolean; +} + +// @ts-ignore - Export this in the future when cutting over to new track interfaces more widely +function isLocalTrack(subject: unknown): subject is ILocalTrack { + return isTrack(subject) && 'isLocal' in subject && subject.isLocal === true; +} + +export const RemoteTrackSymbol: symbol = Symbol.for('lk.remote-track'); + +/** An interface representing a track (of any type) which is remote and receiving data from the SFU. */ +export interface IRemoteTrack extends ITrack { + readonly isLocal: false; +} + +// @ts-ignore - Export this in the future when cutting over to new track interfaces more widely +function isRemoteTrack(subject: unknown): subject is IRemoteTrack { + return ( + isTrack(subject) && 'localitySymbol' in subject && subject.localitySymbol === RemoteTrackSymbol + ); +} + +export const DataTrackSymbol: symbol = Symbol.for('lk.data-track'); +/** An interface representing a data track, either local or remote. */ +export interface IDataTrack extends ITrack { + readonly typeSymbol: typeof DataTrackSymbol; + + readonly info: DataTrackInfo; +} + +export function isDataTrack(subject: unknown): subject is IDataTrack { + return isTrack(subject) && 'typeSymbol' in subject && subject.typeSymbol === DataTrackSymbol; +} diff --git a/src/room/data-track/types.ts b/src/room/data-track/types.ts new file mode 100644 index 0000000000..1b04fdf67b --- /dev/null +++ b/src/room/data-track/types.ts @@ -0,0 +1,11 @@ +import { type DataTrackHandle } from './handle'; + +export type DataTrackSid = string; + +/** Information about a published data track. */ +export type DataTrackInfo = { + sid: DataTrackSid; + pubHandle: DataTrackHandle; + name: String; + usesE2ee: boolean; +}; diff --git a/src/utils/abort-signal-polyfill.ts b/src/utils/abort-signal-polyfill.ts new file mode 100644 index 0000000000..6c2ef49931 --- /dev/null +++ b/src/utils/abort-signal-polyfill.ts @@ -0,0 +1,63 @@ +/** + * Implementation of AbortSignal.any + * Creates a signal that will be aborted when any of the given signals is aborted. + * @link https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/any + */ +export function abortSignalAny(signals: Array): AbortSignal { + // Handle empty signals array + if (signals.length === 0) { + const controller = new AbortController(); + return controller.signal; + } + + // Fast path for single signal + if (signals.length === 1) { + return signals[0]; + } + + // Check if any signal is already aborted + for (const signal of signals) { + if (signal.aborted) { + return signal; + } + } + + // Create a new controller for the combined signal + const controller = new AbortController(); + const unlisteners: Array<() => void> = Array(signals.length); + + // Function to clean up all event listeners + const cleanup = () => { + for (const unsubscribe of unlisteners) { + unsubscribe(); + } + }; + + // Add event listeners to each signal + signals.forEach((signal, index) => { + const handler = () => { + controller.abort(signal.reason); + cleanup(); + }; + + signal.addEventListener('abort', handler); + unlisteners[index] = () => signal.removeEventListener('abort', handler); + }); + + return controller.signal; +} + +/** + * Implementation of AbortSignal.timeout + * Creates a signal that will be aborted after the specified timeout. + * @link https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/timeout + */ +export function abortSignalTimeout(ms: number): AbortSignal { + const controller = new AbortController(); + + setTimeout(() => { + controller.abort(new DOMException(`signal timed out after ${ms} ms`, 'TimeoutError')); + }, ms); + + return controller.signal; +} diff --git a/src/utils/subscribeToEvents.ts b/src/utils/subscribeToEvents.ts index 8d9b623c06..292f661197 100644 --- a/src/utils/subscribeToEvents.ts +++ b/src/utils/subscribeToEvents.ts @@ -42,7 +42,11 @@ export function subscribeToEvents< >(eventName: EventName): Promise { // If an event is already buffered which hasn't been processed yet, pull that off the buffer // and use it. - const earliestBufferedEvent = buffers.get(eventName)!.shift(); + const buffer = buffers.get(eventName); + if (!buffer) { + throw new Error(`No events were buffered / received for event "${eventName.toString()}".`); + } + const earliestBufferedEvent = buffer.shift(); if (earliestBufferedEvent) { return earliestBufferedEvent as EventPayload; } @@ -53,6 +57,19 @@ export function subscribeToEvents< const nextEvent = await future.promise; return nextEvent as EventPayload; }, + /** Are there events of the given name which are waiting to be processed? Use this to assert + * that no unexpected events have been emitted. */ + areThereBufferedEvents< + EventPayload extends Parameters[0], + EventName extends EventNames = EventNames, + >(eventName: EventName) { + const buffer = buffers.get(eventName); + if (buffer) { + return buffer.length > 0; + } else { + return false; + } + }, /** Cleanup any lingering subscriptions. */ unsubscribe: () => { for (const [eventName, onEvent] of eventHandlers) { diff --git a/src/utils/throws.ts b/src/utils/throws.ts index 3137523ba9..d1bc9ac5f2 100644 --- a/src/utils/throws.ts +++ b/src/utils/throws.ts @@ -1,4 +1,4 @@ -type Primitives = null | undefined | string | number | bigint | boolean | symbol; +type Primitives = void | null | undefined | string | number | bigint | boolean | symbol; /** * Branded type that encodes possible thrown errors in the return type. diff --git a/throws-transformer/engine.ts b/throws-transformer/engine.ts index 7a95e1dc81..1994d4c791 100644 --- a/throws-transformer/engine.ts +++ b/throws-transformer/engine.ts @@ -44,6 +44,12 @@ export function checkSourceFile( if (rejectResult) { results.push(rejectResult); } + + // Check if this is a reject() call from Promise constructor + const rejectParamResult = checkRejectParameterCall(node, sourceFile, checker); + if (rejectParamResult) { + results.push(rejectParamResult); + } } // Check await expressions @@ -219,6 +225,124 @@ function checkPromiseReject( return null; } +function checkRejectParameterCall( + node: ts.CallExpression, + sourceFile: ts.SourceFile, + checker: ts.TypeChecker, +): CheckResult | null { + // Check if this is a reject() call (not Promise.reject()) + if (!ts.isIdentifier(node.expression) || node.expression.text !== 'reject') { + return null; + } + + // Check if reject is a parameter from a Promise constructor + const promiseConstructor = getPromiseConstructorForReject(node.expression, checker); + if (!promiseConstructor) { + return null; + } + + // Get the function that contains the Promise constructor (not the executor function) + const containingFunction = getContainingFunction(promiseConstructor); + if (!containingFunction) { + return null; + } + + // Check if handled by local catch + const tryCatch = getContainingTryCatch(promiseConstructor); + if (tryCatch) { + return null; // Handled by try-catch + } + + // Get declared error types from the function's return type + const declaredErrors = getDeclaredErrorTypes(containingFunction, checker); + if (declaredErrors === null) { + return null; // Not using Throws<> + } + + // Get the type of the rejected value + if (node.arguments.length === 0) { + return null; // reject() with no argument + } + + const rejectedType = checker.getTypeAtLocation(node.arguments[0]); + const rejectedTypeName = checker.typeToString(rejectedType); + + const isAllowed = isErrorTypeDeclared(checker, rejectedType, declaredErrors); + + if (!isAllowed) { + const start = node.getStart(); + const length = node.getWidth(); + const { line, character } = sourceFile.getLineAndCharacterOfPosition(start); + const declaredNames = declaredErrors + .map((t) => checker.typeToString(t)) + .join(" | "); + + return { + sourceFile, + line: line + 1, + column: character + 1, + start, + length, + + functionName: "reject", + unhandledErrors: [rejectedTypeName], + message: `reject('${rejectedTypeName}') in Promise constructor but it's not declared. Declared: ${declaredNames || "never"}. Add it to Throws<> in your return type.`, + }; + } + + return null; +} + +function getPromiseConstructorForReject( + identifier: ts.Identifier, + checker: ts.TypeChecker, +): ts.NewExpression | null { + // Get the symbol for the identifier + const symbol = checker.getSymbolAtLocation(identifier); + if (!symbol) { + return null; + } + + // Check if it's a parameter + const declarations = symbol.getDeclarations(); + if (!declarations || declarations.length === 0) { + return null; + } + + const declaration = declarations[0]; + if (!ts.isParameter(declaration)) { + return null; + } + + // Check if the parameter is from a function that's passed to Promise constructor + const paramFunction = declaration.parent; + if (!ts.isFunctionExpression(paramFunction) && !ts.isArrowFunction(paramFunction)) { + return null; + } + + // Check if this function is an argument to a Promise constructor + const parent = paramFunction.parent; + + // Handle both direct argument and parenthesized expressions + let newExpr: ts.Node | undefined = parent; + while (newExpr && ts.isParenthesizedExpression(newExpr)) { + newExpr = newExpr.parent; + } + + if (!newExpr || !ts.isNewExpression(newExpr)) { + return null; + } + + // Check if it's a new Promise(...) call + const type = checker.getTypeAtLocation(newExpr.expression); + const typeSymbol = type.getSymbol(); + if (typeSymbol?.getName() === 'Promise' || typeSymbol?.getName() === 'PromiseConstructor') { + return newExpr as ts.NewExpression; + } + + return null; +} + function checkReturnStatement( node: ts.ReturnStatement, sourceFile: ts.SourceFile, @@ -233,11 +357,33 @@ function checkReturnStatement( return null; } - // Get the type of the returned expression - const returnedType = checker.getTypeAtLocation(node.expression); + // Check if this is a .catch() or .then() call on a promise + let returnedErrors: ts.Type[] = []; - // Extract error types from the returned value - const returnedErrors = extractThrowsErrorTypes(returnedType, checker); + if (ts.isCallExpression(node.expression) && + ts.isPropertyAccessExpression(node.expression.expression)) { + const methodName = node.expression.expression.name.text; + + if (methodName === 'catch') { + // For .catch(), check if it handles all errors or rethrows some + returnedErrors = extractErrorsFromPromiseCatch(node.expression, checker); + } else if (methodName === 'then') { + // For .then(), errors propagate from the original promise + const promiseExpr = node.expression.expression.expression; + const promiseType = checker.getTypeAtLocation(promiseExpr); + returnedErrors = extractThrowsErrorTypes(promiseType, checker); + } else { + // Regular method call, extract errors normally + const returnedType = checker.getTypeAtLocation(node.expression); + returnedErrors = extractThrowsErrorTypes(returnedType, checker); + } + } else { + // Get the type of the returned expression + const returnedType = checker.getTypeAtLocation(node.expression); + + // Extract error types from the returned value + returnedErrors = extractThrowsErrorTypes(returnedType, checker); + } if (returnedErrors.length === 0) { return null; // No errors in returned value @@ -279,6 +425,66 @@ function checkReturnStatement( }; } +function extractErrorsFromPromiseCatch( + catchCall: ts.CallExpression, + checker: ts.TypeChecker, +): ts.Type[] { + // Get errors from the original promise (before .catch()) + const promiseExpr = (catchCall.expression as ts.PropertyAccessExpression).expression; + const promiseType = checker.getTypeAtLocation(promiseExpr); + const originalErrors = extractThrowsErrorTypes(promiseType, checker); + + if (originalErrors.length === 0) { + return []; + } + + // Check the catch handler + if (catchCall.arguments.length === 0) { + return originalErrors; // No handler, errors propagate + } + + const handler = catchCall.arguments[0]; + + // If it's a function, check if it rethrows + if (ts.isFunctionExpression(handler) || ts.isArrowFunction(handler)) { + if (!handler.body) { + return []; // No body means it silences errors + } + + // Check if the body contains a throw statement + if (ts.isBlock(handler.body)) { + if (containsThrowStatement(handler.body)) { + // Handler rethrows - find what it throws + const thrownErrors: ts.Type[] = []; + + function visitThrow(node: ts.Node): void { + if (ts.isThrowStatement(node) && node.expression) { + const thrownType = checker.getTypeAtLocation(node.expression); + if (!isAnyOrUnknownType(thrownType)) { + thrownErrors.push(thrownType); + } + } + ts.forEachChild(node, visitThrow); + } + + visitThrow(handler.body); + + // Return the new errors thrown by the handler + return thrownErrors.length > 0 ? thrownErrors : originalErrors; + } else { + // Handler doesn't rethrow, errors are silenced + return []; + } + } else { + // Expression body (arrow function), doesn't throw + return []; + } + } + + // If it's not a function expression, be conservative and assume errors propagate + return originalErrors; +} + function isHandledByLocalCatch( throwNode: ts.ThrowStatement, containingFunction: ts.FunctionLikeDeclaration, @@ -332,6 +538,17 @@ function checkCallExpression( sourceFile: ts.SourceFile, checker: ts.TypeChecker, ): CheckResult | null { + // Check if this IS a .catch() call with a handler that silences errors + if (isCatchCallThatSilencesErrors(node)) { + return null; + } + + // Check if this is a promise-returning call that's being chained or not immediately consumed + // In these cases, the error is contained in the promise and will be handled later + if (isPromiseCallNotImmediatelyConsumed(node, checker)) { + return null; + } + // Get the return type of the call const callType = checker.getTypeAtLocation(node); @@ -360,14 +577,25 @@ function checkCallExpression( return null; } - const propagatedErrors = containingFunction + const propagatedErrorTypes = containingFunction ? getPropagatedErrorTypes(node, containingFunction, sourceFile, checker) - : new Set(); + : []; - // Find unhandled + // Find unhandled - use type compatibility checking for propagated errors const unhandledErrors = errorTypes.filter((errorType) => { const errorName = checker.typeToString(errorType); - return !handledErrors.has(errorName) && !propagatedErrors.has(errorName); + + // Check if handled by catch + if (handledErrors.has(errorName)) { + return false; + } + + // Check if propagated using generic type compatibility + if (isErrorTypeDeclared(checker, errorType, propagatedErrorTypes)) { + return false; + } + + return true; }); if (unhandledErrors.length === 0) { @@ -399,6 +627,81 @@ function checkCallExpression( }; } +function isPromiseCallNotImmediatelyConsumed(node: ts.CallExpression, checker: ts.TypeChecker): boolean { + // Check if this call returns a Promise + const returnType = checker.getTypeAtLocation(node); + const promiseType = extractPromiseType(returnType, checker); + + // If it doesn't return a promise, we should check it + if (!promiseType) { + return false; + } + + // Check the parent context + const parent = node.parent; + + // If it's being chained (e.g., someCall().then(...)) + if (ts.isPropertyAccessExpression(parent)) { + const methodName = parent.name.text; + const promiseMethods = ['then', 'catch', 'finally']; + if (promiseMethods.includes(methodName)) { + return true; + } + } + + // If it's in a variable declaration (e.g., const x = someCall()) + if (ts.isVariableDeclaration(parent)) { + return true; + } + + // If it's in a return statement, we SHOULD check it (it's being consumed) + if (ts.isReturnStatement(parent)) { + return false; + } + + // If it's being awaited, we SHOULD check it (it's being consumed) + if (ts.isAwaitExpression(parent)) { + return false; + } + + // Default: if returning a promise and not explicitly consumed, skip checking + // This handles cases like assignment, being passed as an argument, etc. + return true; +} + +function isCatchCallThatSilencesErrors(node: ts.CallExpression): boolean { + // Check if this is a .catch() method call + if (!ts.isPropertyAccessExpression(node.expression)) { + return false; + } + + if (node.expression.name.text !== 'catch') { + return false; + } + + // Check if the catch handler silences errors + if (node.arguments.length === 0) { + return false; // No handler + } + + const handler = node.arguments[0]; + if (ts.isFunctionExpression(handler) || ts.isArrowFunction(handler)) { + if (!handler.body) { + return true; // No body means it silences + } + + if (ts.isBlock(handler.body)) { + // Check if it contains a throw statement + return !containsThrowStatement(handler.body); + } else { + // Expression body doesn't throw + return true; + } + } + + return false; +} + function extractThrowsErrorTypes( type: ts.Type, checker: ts.TypeChecker, @@ -488,11 +791,85 @@ function isErrorTypeDeclared( if (baseTypes.some((base) => checker.typeToString(base) === declaredName)) { return true; } + + // Check if types share the same base generic type with compatible type arguments + if (areGenericTypesCompatible(checker, thrownType, declared)) { + return true; + } } return false; } +function areGenericTypesCompatible( + checker: ts.TypeChecker, + source: ts.Type, + target: ts.Type, +): boolean { + // Cast to TypeReference to access type arguments + const sourceRef = source as ts.TypeReference; + const targetRef = target as ts.TypeReference; + + // Both must be type references with type arguments + if (!sourceRef.typeArguments || !targetRef.typeArguments) { + return false; + } + + // Check if they have the same base symbol (e.g., both are TypedError<...>) + const sourceSymbol = source.getSymbol(); + const targetSymbol = target.getSymbol(); + + if (!sourceSymbol || !targetSymbol || sourceSymbol !== targetSymbol) { + return false; + } + + // Both have the same generic base, now check if type arguments are compatible + if (sourceRef.typeArguments.length !== targetRef.typeArguments.length) { + return false; + } + + for (let i = 0; i < sourceRef.typeArguments.length; i++) { + const sourceArg = sourceRef.typeArguments[i]; + const targetArg = targetRef.typeArguments[i]; + + if (!isTypeArgumentAssignable(checker, sourceArg, targetArg)) { + return false; + } + } + + return true; +} + +function isTypeArgumentAssignable( + checker: ts.TypeChecker, + source: ts.Type, + target: ts.Type, +): boolean { + const sourceName = checker.typeToString(source); + const targetName = checker.typeToString(target); + + // Exact match + if (sourceName === targetName) { + return true; + } + + // Check if target is a union and source is one of its members + if (target.isUnion() && target.types.length > 1) { + return target.types.some(t => { + const tName = checker.typeToString(t); + return tName === sourceName; + }); + } + + // Check if source is an enum member and target is the enum + // e.g., source is "Types.Foo" and target is "Types" + if (sourceName.includes('.') && sourceName.startsWith(targetName + '.')) { + return true; + } + + return false; +} + function getBaseTypes(checker: ts.TypeChecker, type: ts.Type): ts.Type[] { const bases: ts.Type[] = []; @@ -867,26 +1244,17 @@ function getPropagatedErrorTypes( func: ts.FunctionLikeDeclaration, sourceFile: ts.SourceFile, checker: ts.TypeChecker, -): Set { - const propagated = new Set(); - - if (!func.type) { return propagated; } +): ts.Type[] { + if (!func.type) { return []; } // If `node` is in a try/catch, then the errors propegated are the errors that the catch itself throws const tryCatch = getContainingTryCatch(node); if (tryCatch?.catchClause) { - const thrownErrorTypes = getTryCatchThrownErrors(tryCatch, sourceFile, checker); - return new Set(thrownErrorTypes.map(e => checker.typeToString(e))); + return getTryCatchThrownErrors(tryCatch, sourceFile, checker); } const returnType = checker.getTypeFromTypeNode(func.type); - const errorTypes = extractThrowsErrorTypes(returnType, checker); - - for (const errorType of errorTypes) { - propagated.add(checker.typeToString(errorType)); - } - - return propagated; + return extractThrowsErrorTypes(returnType, checker); } function getFunctionName(node: ts.CallExpression): string {