diff --git a/.gitignore b/.gitignore index 5e5ff962..d5c23c20 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,4 @@ DerivedData/ .*.history.json .claude .build +.swiftpm diff --git a/.parity/check-swift-mapper-parity.mjs b/.parity/check-swift-mapper-parity.mjs index ae930879..58bc9faf 100644 --- a/.parity/check-swift-mapper-parity.mjs +++ b/.parity/check-swift-mapper-parity.mjs @@ -1,4 +1,5 @@ import { app } from 'electron' +import * as fs from 'node:fs/promises' import * as path from 'node:path' import { fileURLToPath } from 'node:url' @@ -31,7 +32,7 @@ const { const args = parseArgs(process.argv.slice(2)) const referenceRoot = path.resolve(args.get('reference-root') ?? path.join(repoRoot, '.parity/platform-imessage-main')) -const defaultReferenceIMessageNodePath = path.join(referenceRoot, 'binaries', `${process.platform}-${process.arch}`, 'IMessage.node') +const defaultReferenceIMessageNodePath = path.join(referenceRoot, 'binaries', `${process.platform}-${process.arch}`, 'SwiftServer.node') const referenceIMessageNodePath = args.get('reference-swift-server-node') ?? defaultReferenceIMessageNodePath const referenceBinariesDirPath = args.get('reference-binaries-dir') ?? path.dirname(path.dirname(referenceIMessageNodePath)) @@ -229,21 +230,34 @@ function searchTermsFromMessage(message) { } if (childRole) { - await runAPIChild({ - role: childRole, + try { + await runAPIChild({ + role: childRole, + repoRoot, + referenceAPIPath: args.get('reference-api-bundle'), + referenceBinariesDirPath, + }) + process.exit(process.exitCode ?? 0) + } catch (error) { + console.error(error instanceof Error ? error.stack ?? error.message : String(error)) + app.exit(1) + process.exit(1) + } +} + +let referenceAPIPath +try { + referenceAPIPath = await ensureReferenceAPI({ + args, repoRoot, - referenceAPIPath: args.get('reference-api-bundle'), + referenceRoot, referenceBinariesDirPath, }) - process.exit(0) +} catch (error) { + console.error(error instanceof Error ? error.stack ?? error.message : String(error)) + app.exit(1) + process.exit(1) } - -const referenceAPIPath = await ensureReferenceAPI({ - args, - repoRoot, - referenceRoot, - referenceBinariesDirPath, -}) const childAPIs = [ spawnAPIChild({ role: 'current', @@ -500,7 +514,7 @@ try { byDiff[failure.details] = (byDiff[failure.details] ?? 0) + 1 } - console.log(JSON.stringify({ + const summary = { chatLimit: formatLimit(chatLimit), skipChats, messageLimitPerChat: formatLimit(messageLimit), @@ -521,7 +535,15 @@ try { byDiff, perfDeltas: summarizePerfDeltas(), failures, - }, null, 2)) + } + const summaryJSON = JSON.stringify(summary, null, 2) + const outputJSONPath = args.get('output-json') + if (outputJSONPath) { + await fs.writeFile(outputJSONPath, `${summaryJSON}\n`) + } + if (!args.has('no-stdout-json')) { + console.log(summaryJSON) + } process.exitCode = failures.length === 0 ? 0 : 1 } finally { diff --git a/.parity/parity-child-processes.mjs b/.parity/parity-child-processes.mjs index 8a5b5b34..8dc0ccce 100644 --- a/.parity/parity-child-processes.mjs +++ b/.parity/parity-child-processes.mjs @@ -108,6 +108,9 @@ export async function runAPIChild(options) { writeIPC({ id: request.id, ok: false, error: serializeError(result.error), ms: result.ms }) } } + } catch (error) { + writeIPC({ type: 'startup-error', error: serializeError(error) }) + process.exitCode = 1 } finally { await Promise.resolve(api?.dispose?.()).catch(() => {}) if (dataDirPath) await fs.rm(dataDirPath, { recursive: true, force: true }).catch(() => {}) @@ -177,6 +180,10 @@ export function spawnAPIChild({ readyResolve() return } + if (message.type === 'startup-error') { + readyReject(deserializeError(message.error)) + return + } const request = pending.get(message.id) if (!request) return diff --git a/.parity/parity-utils.mjs b/.parity/parity-utils.mjs index b07063cc..fcc926d9 100644 --- a/.parity/parity-utils.mjs +++ b/.parity/parity-utils.mjs @@ -46,6 +46,43 @@ export function exec(command, commandArgs, cwd) { execFileSync(command, commandArgs, { cwd, stdio: 'inherit' }) } +async function ensureReferenceDependencies(referenceRoot) { + if (!await pathExists(path.join(referenceRoot, 'node_modules'))) { + exec('yarn', [], referenceRoot) + } +} + +async function findReferenceNativeModule(referenceBinariesDirPath) { + const archBinariesDirPath = path.join(referenceBinariesDirPath, `${process.platform}-${process.arch}`) + for (const fileName of ['IMessage.node', 'SwiftServer.node']) { + const candidate = path.join(archBinariesDirPath, fileName) + if (await pathExists(candidate)) return candidate + } + return undefined +} + +async function ensureReferenceNativeModule({ args, referenceRoot, referenceBinariesDirPath }) { + if (args.get('reference-swift-server-node')) return + if (await findReferenceNativeModule(referenceBinariesDirPath)) return + + if (args.has('skip-reference-rebuild') && !args.has('rebuild-reference')) { + throw new Error( + `Reference native module is missing under ${path.join(referenceBinariesDirPath, `${process.platform}-${process.arch}`)}. ` + + 'Run again without --skip-reference-rebuild, or pass --reference-binaries-dir to a directory containing the reference Swift .node binary.', + ) + } + + await ensureReferenceDependencies(referenceRoot) + exec('bun', ['build:swift', '--standalone'], referenceRoot) + + if (!await findReferenceNativeModule(referenceBinariesDirPath)) { + throw new Error( + `Reference Swift build finished, but no IMessage.node or SwiftServer.node was found under ` + + `${path.join(referenceBinariesDirPath, `${process.platform}-${process.arch}`)}.`, + ) + } +} + export async function readDefaultReferenceRef(repoRoot) { const refFile = path.join(repoRoot, '.parity/REFERENCE_REF') try { @@ -73,9 +110,9 @@ export async function ensureReferenceAPI({ didCreateReferenceRoot = true } if (didCreateReferenceRoot) { - exec('yarn', [], referenceRoot) - exec('bun', ['build:swift', '--standalone'], referenceRoot) + await ensureReferenceDependencies(referenceRoot) } + await ensureReferenceNativeModule({ args, referenceRoot, referenceBinariesDirPath }) if (!args.has('skip-reference-rebuild') || args.has('rebuild-reference') || !await pathExists(bundlePath)) { const binariesDirPathLiteral = JSON.stringify(referenceBinariesDirPath) const buildBanner = `globalThis.texts={IS_DEV:true,isLoggingEnabled:false,log(){},error(){},constants:{USER_AGENT:'platform-imessage-parity',APP_VERSION:'1.0.0'},Sentry:{captureException(){},captureMessage(){},startTransaction(){}},async trackPlatformEvent(){},getBinariesDirPath(){return ${binariesDirPathLiteral}},fetch:globalThis.fetch,fetchStream:undefined,createHttpClient:undefined,nativeFetch:undefined,nativeFetchStream:undefined,runWorker:undefined,forkChildProcess:undefined,getOriginalObject:undefined,openBrowserWindow:undefined};` diff --git a/Package.resolved b/Package.resolved index 21e07d62..a5fbe2f6 100644 --- a/Package.resolved +++ b/Package.resolved @@ -18,6 +18,15 @@ "version" : "2.2.0" } }, + { + "identity" : "grdb.swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/groue/GRDB.swift.git", + "state" : { + "revision" : "2cf6c756e1e5ef6901ebae16576a7e4e4b834622", + "version" : "6.29.3" + } + }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 5be327ac..f967b873 100644 --- a/Package.swift +++ b/Package.swift @@ -15,6 +15,7 @@ var products: [Product] = [ targets: ["IMessage"] ), .executable(name: "imessage-cli", targets: ["IMessageCLI"]), + .executable(name: "IMessagePerfBench", targets: ["IMessagePerfBench"]), ] var dependencies: [Package.Dependency] = [ @@ -24,6 +25,7 @@ var dependencies: [Package.Dependency] = [ .package(url: "https://github.com/apple/swift-collections.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-async-algorithms", from: "1.0.0"), .package(url: "https://github.com/apple/swift-argument-parser", from: "1.6.1"), + .package(url: "https://github.com/groue/GRDB.swift.git", from: "6.29.3"), .package(url: "https://github.com/swiftlang/swift-syntax.git", exact: "603.0.0-prerelease-2025-10-30"), ] @@ -76,13 +78,21 @@ var targets: [Target] = [ dependencies: ["SQLite"], path: "src/IMessage/Sources/SQLiteTests" ), + .testTarget( + name: "IMDatabaseTests", + dependencies: [ + "IMDatabase", + .product(name: "GRDB", package: "GRDB.swift"), + ], + path: "src/IMessage/Sources/IMDatabaseTests" + ), .target( name: "IMDatabase", dependencies: [ .product(name: "Logging", package: "swift-log"), .product(name: "AsyncAlgorithms", package: "swift-async-algorithms"), .product(name: "Collections", package: "swift-collections"), - "SQLite", + .product(name: "GRDB", package: "GRDB.swift"), "ExceptionCatcher", "IMessageCore", ], @@ -98,6 +108,16 @@ var targets: [Target] = [ path: "src/IMessage/Sources/IMessageCLI", plugins: ["GenerateIMessageCLIVersionPlugin"] ), + .executableTarget( + name: "IMessagePerfBench", + dependencies: [ + "IMDatabase", + "IMessage", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ], + path: "src/IMessage/Sources/IMessagePerfBench", + exclude: ["README.md"] + ), .plugin( name: "GenerateIMessageCLIVersionPlugin", capability: .buildTool(), diff --git a/package.json b/package.json index 5e613d0c..55187b31 100644 --- a/package.json +++ b/package.json @@ -22,6 +22,7 @@ "lint:js": "yarn eslint src --ext ts,tsx,js,jsx --cache", "cli:js": "env NODE_OPTIONS=\"--force-node-api-uncaught-exceptions-policy=true\" electron cli.compiled.mjs", "cli": "swift run imessage-cli", + "perf:imessage": "node scripts/imessage-perf.mjs", "swift-mapper-parity": "yarn build:swift-mapper-parity && env NODE_OPTIONS=\"--force-node-api-uncaught-exceptions-policy=true\" electron .parity/check-swift-mapper-parity.compiled.mjs", "build:cli:release": "sh -c 'swift build -c release --product imessage-cli >/dev/null && bin_path=$(swift build -c release --product imessage-cli --show-bin-path) && printf \"%s/imessage-cli\\n\" \"$bin_path\"'", "build:cli:js": "bun build src/cli/index.ts --target=node --format=esm --external electron --external @textshq/platform-test-lib --outfile=cli.compiled.mjs", diff --git a/scripts/imessage-perf.mjs b/scripts/imessage-perf.mjs new file mode 100755 index 00000000..86fe62a8 --- /dev/null +++ b/scripts/imessage-perf.mjs @@ -0,0 +1,366 @@ +#!/usr/bin/env node +import { execFileSync, spawnSync } from 'node:child_process' +import * as fsSync from 'node:fs' +import * as fs from 'node:fs/promises' +import * as os from 'node:os' +import * as path from 'node:path' +import { fileURLToPath } from 'node:url' + +const scriptDir = path.dirname(fileURLToPath(import.meta.url)) +const repoRoot = path.resolve(scriptDir, '..') +const productName = 'IMessagePerfBench' + +const args = parseArgs(process.argv.slice(2)) +const noBuild = args.has('no-build') +const debug = args.has('debug') +const jsonOutput = args.has('json') +const withParity = args.has('with-parity') || args.has('parity-only') +const parityOnly = args.has('parity-only') +const configuration = debug ? 'debug' : 'release' + +if (args.has('help')) { + printHelp() + process.exit(0) +} + +const output = { + swift: null, + parity: null, +} + +if (!parityOnly) { + if (!noBuild) buildSwiftBench() + output.swift = runSwiftBench() +} + +if (withParity) { + output.parity = await runParity() +} + +if (jsonOutput) { + console.log(JSON.stringify(output, null, 2)) +} else { + if (output.swift) printSwiftReport(output.swift) + if (output.parity) printParityReport(output.parity) +} + +function parseArgs(argv) { + const parsed = new Map() + for (let index = 0; index < argv.length; index += 1) { + const arg = argv[index] + const equalsMatch = arg.match(/^--([^=]+)=(.*)$/) + if (equalsMatch) { + parsed.set(equalsMatch[1], equalsMatch[2]) + continue + } + const flagMatch = arg.match(/^--(.+)$/) + if (!flagMatch) continue + + const name = flagMatch[1] + const next = argv[index + 1] + if (next && !next.startsWith('--') && optionTakesValue(name)) { + parsed.set(name, next) + index += 1 + } else { + parsed.set(name, '1') + } + } + return parsed +} + +function optionTakesValue(name) { + return [ + 'messages-dir', + 'iterations', + 'warmups', + 'max-chats', + 'message-limit', + 'api-thread-samples', + 'search-query', + 'api-timeout-ms', + 'parity-timeout-ms', + 'parity-max-chats', + 'parity-max-messages-per-chat', + 'get-message-samples', + 'search-samples', + 'progress-every', + 'reference-root', + 'reference-ref', + 'reference-swift-server-node', + 'reference-binaries-dir', + ].includes(name) +} + +function buildSwiftBench() { + run('swift', ['build', '-c', configuration, '--product', productName], { stdio: 'inherit' }) +} + +function runSwiftBench() { + const binDir = execFileSync('swift', [ + 'build', + '-c', + configuration, + '--product', + productName, + '--show-bin-path', + ], { cwd: repoRoot, encoding: 'utf8' }).trim() + const binPath = path.join(binDir, productName) + const swiftArgs = ['--format', 'json'] + + for (const name of [ + 'messages-dir', + 'iterations', + 'warmups', + 'max-chats', + 'message-limit', + 'api-thread-samples', + 'search-query', + ]) { + if (args.has(name)) swiftArgs.push(`--${name}`, args.get(name)) + } + for (const name of ['create-indexes', 'sql-only', 'api-only']) { + if (args.has(name)) swiftArgs.push(`--${name}`) + } + + const result = spawnSync(binPath, swiftArgs, { + cwd: repoRoot, + encoding: 'utf8', + }) + if (result.error) { + throw result.error + } + if (result.status !== 0) { + process.stderr.write(result.stdout ?? '') + process.stderr.write(result.stderr ?? '') + process.exit(result.status ?? 1) + } + + try { + return JSON.parse(result.stdout) + } catch (error) { + process.stderr.write(result.stdout ?? '') + throw new Error(`Could not parse ${productName} JSON output: ${error.message}`) + } +} + +async function runParity() { + if (!noBuild) { + run('yarn', ['build:swift-mapper-parity'], { stdio: 'inherit' }) + } + + const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'imessage-perf-parity-')) + const outputJSONPath = path.join(tempDir, 'parity.json') + const parityArgs = [ + '.parity/check-swift-mapper-parity.compiled.mjs', + `--output-json=${outputJSONPath}`, + '--no-stdout-json', + `--max-chats=${args.get('parity-max-chats') ?? args.get('max-chats') ?? '5'}`, + `--max-messages-per-chat=${args.get('parity-max-messages-per-chat') ?? args.get('message-limit') ?? '20'}`, + `--get-message-samples=${args.get('get-message-samples') ?? '0'}`, + `--search-samples=${args.get('search-samples') ?? '0'}`, + `--progress-every=${args.get('progress-every') ?? '1'}`, + `--call-timeout-ms=${args.get('api-timeout-ms') ?? '5000'}`, + ] + + for (const name of [ + 'reference-root', + 'reference-ref', + 'reference-swift-server-node', + 'reference-binaries-dir', + ]) { + if (args.has(name)) parityArgs.push(`--${name}=${args.get(name)}`) + } + for (const name of ['skip-reference-rebuild', 'rebuild-reference', 'forward-child-output']) { + if (args.has(name)) parityArgs.push(`--${name}`) + } + + const parityTimeoutMs = Number.parseInt(args.get('parity-timeout-ms') ?? '120000', 10) + const result = spawnSync(resolveElectron(), parityArgs, { + cwd: repoRoot, + encoding: 'utf8', + stdio: ['ignore', 'pipe', 'pipe'], + timeout: Number.isFinite(parityTimeoutMs) ? parityTimeoutMs : undefined, + killSignal: 'SIGTERM', + }) + if (result.error?.code === 'ETIMEDOUT') { + process.stderr.write(`Parity run timed out after ${parityTimeoutMs}ms. Increase --parity-timeout-ms or run the parity command directly with --forward-child-output.\n`) + process.exit(124) + } + if (result.error) { + throw result.error + } + if (result.stdout) process.stderr.write(result.stdout) + if (result.status !== 0) { + process.stderr.write(result.stderr ?? '') + process.exit(result.status ?? 1) + } + + try { + return JSON.parse(await fs.readFile(outputJSONPath, 'utf8')) + } finally { + await fs.rm(tempDir, { recursive: true, force: true }).catch(() => {}) + } +} + +function run(command, commandArgs, options = {}) { + const result = spawnSync(command, commandArgs, { + cwd: repoRoot, + encoding: 'utf8', + ...options, + }) + if (result.error) { + throw result.error + } + if (result.status !== 0) { + if (result.stdout) process.stderr.write(result.stdout) + if (result.stderr) process.stderr.write(result.stderr) + process.exit(result.status ?? 1) + } + return result +} + +function resolveBin(name) { + const localPath = path.join(repoRoot, 'node_modules', '.bin', name) + return fsSync.existsSync(localPath) ? localPath : name +} + +function resolveElectron() { + const macElectronPath = path.join( + repoRoot, + 'node_modules', + 'electron', + 'dist', + 'Electron.app', + 'Contents', + 'MacOS', + 'Electron', + ) + return fsSync.existsSync(macElectronPath) ? macElectronPath : resolveBin('electron') +} + +function printSwiftReport(report) { + console.log(bold('iMessage performance benchmarks')) + console.log(dim(`messages dir: ${shortenHome(report.metadata.messagesDir)}`)) + console.log(dim(`iterations: ${report.metadata.iterations}, warmups: ${report.metadata.warmups}`)) + console.log() + printBenchTable('SQL hot paths', report.sql.results) + console.log() + printBenchTable('Platform API', report.api.results) +} + +function printBenchTable(title, results) { + console.log(bold(title)) + if (!results?.length) { + console.log(dim('skipped')) + return + } + printTable([ + ['name', 'rows', 'avg ms', 'p50 ms', 'p95 ms', 'min ms', 'max ms'], + ...results.map(result => [ + result.name, + String(result.resultCount), + formatMS(result.averageMS), + formatMS(result.p50MS), + formatMS(result.p95MS), + formatMS(result.minMS), + formatMS(result.maxMS), + ]), + ], { numericColumns: new Set([1, 2, 3, 4, 5, 6]) }) +} + +function printParityReport(report) { + console.log() + console.log(bold('Parity API comparison')) + console.log(dim(`checked chats: ${report.chatsChecked}, getThreads pages: ${report.getThreadsPagesChecked}, getMessages pages: ${report.getMessagesPagesChecked}`)) + if (report.strictFailures > 0) { + console.log(red(`strict failures: ${report.strictFailures}`)) + } else { + console.log(green('strict failures: 0')) + } + + const byPhase = report.perfDeltas?.byPhase ?? {} + const rows = Object.entries(byPhase).map(([phase, value]) => [ + phase, + String(value.samples), + formatMS(value.avgCurrentMs), + formatMS(value.avgReferenceMs), + formatMS(value.avgDeltaMs), + value.aggregateRatio == null ? '-' : value.aggregateRatio.toFixed(3), + ]) + if (rows.length) { + printTable([ + ['phase', 'samples', 'current avg', 'reference avg', 'delta avg', 'ratio'], + ...rows, + ], { numericColumns: new Set([1, 2, 3, 4, 5]) }) + } +} + +function printTable(rows, { numericColumns = new Set() } = {}) { + const widths = [] + for (const row of rows) { + row.forEach((cell, index) => { + widths[index] = Math.max(widths[index] ?? 0, visibleLength(cell)) + }) + } + rows.forEach((row, rowIndex) => { + const line = row.map((cell, index) => { + const padding = ' '.repeat(widths[index] - visibleLength(cell)) + return numericColumns.has(index) ? `${padding}${cell}` : `${cell}${padding}` + }).join(' ') + console.log(rowIndex === 0 ? dim(line) : line) + }) +} + +function printHelp() { + console.log(`Usage: yarn perf:imessage [options] + +Runs backend-agnostic IMDatabase hot-path benchmarks and PlatformAPI getThreads/getMessages timings. + +Common options: + --iterations Measured iterations per case + --warmups Warmup iterations per case + --max-chats Chats sampled by SQL benchmarks + --message-limit Messages sampled per chat + --api-thread-samples Threads sampled by PlatformAPI.getMessages + --sql-only Skip PlatformAPI benchmarks + --api-only Skip SQL hot-path benchmarks + --create-indexes Ask IMDatabase to create optional read indexes + --with-parity Also run the current-vs-reference parity script + --parity-timeout-ms Overall timeout for --with-parity + --json Emit machine-readable JSON + --no-build Reuse existing built artifacts +`) +} + +function formatMS(value) { + return Number(value).toFixed(3) +} + +function shortenHome(value) { + return value.replace(os.homedir(), '~') +} + +function visibleLength(value) { + return String(value).replace(/\u001b\[[0-9;]*m/g, '').length +} + +function color(open, close, value) { + if (!process.stdout.isTTY) return value + return `${open}${value}${close}` +} + +function bold(value) { + return color('\u001b[1m', '\u001b[0m', value) +} + +function dim(value) { + return color('\u001b[2m', '\u001b[0m', value) +} + +function red(value) { + return color('\u001b[31m', '\u001b[0m', value) +} + +function green(value) { + return color('\u001b[32m', '\u001b[0m', value) +} diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Accounts.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Accounts.swift index 7224f7c7..432458ed 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Accounts.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Accounts.swift @@ -1,14 +1,12 @@ +import GRDB + public extension IMDatabase { func accountLogins() throws -> [String] { - let statement = try cachedStatement(forEscapedSQL: """ - SELECT DISTINCT account_login - FROM chat - """) - - try statement.reset() - - return try statement.compactMapRowsUntilDone { row in - try row[0].optional(String.self) + try read { db in + try Row.fetchAll(db, sql: """ + SELECT DISTINCT account_login + FROM chat + """).compactMap { $0[0] as String? } } } } diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Attachments.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Attachments.swift index a743d207..5104de67 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Attachments.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Attachments.swift @@ -1,6 +1,6 @@ import Collections +import GRDB import Logging -import SQLite import IMessageCore private let log = Logger(imessageLabel: "imdb.db") @@ -14,14 +14,13 @@ LEFT JOIN attachment a ON a.ROWID = maj.attachment_id extension IMDatabase { func hydrateAttachments(for message: inout Message) throws { - let statement = try cachedStatement(forEscapedSQL: """ - \(attachmentQuerySharedPrologue) - WHERE m.guid = ? - """).reset() - try statement.bind(message.guid) - - let attachments = try statement.compactMapRowsUntilDone { row in - try Attachment(row: row) + let attachments = try read { db in + try Row.fetchAll(db, sql: """ + \(attachmentQuerySharedPrologue) + WHERE m.guid = ? + """, arguments: [message.guid]).compactMap { row in + try Attachment(row: row) + } } message.attachments = attachments #if DEBUG @@ -32,44 +31,45 @@ extension IMDatabase { func hydrateAttachments(for messages: inout OrderedDictionary) throws { let messageRowIDs = messages.keys.map(String.init) - let statement = try Statement.prepare(escapedSQL: """ - \(attachmentQuerySharedPrologue) - WHERE m.ROWID IN (\(messageRowIDs.joined(separator: ","))) - """, for: database) + try read { db in + let rows = try Row.fetchAll(db, sql: """ + \(attachmentQuerySharedPrologue) + WHERE m.ROWID IN (\(messageRowIDs.joined(separator: ","))) + """) + for row in rows { + let messageRowID = row.requiredInt(at: 0) - try statement.stepUntilDone { row in - let messageRowID = try row[0].expect(Int.self) + guard messages[messageRowID] != nil else { + assertionFailure() + continue + } - guard messages[messageRowID] != nil else { - assertionFailure() - return - } + if messages[messageRowID]!.attachments == nil { + messages[messageRowID]!.attachments = [] + } - if messages[messageRowID]!.attachments == nil { - messages[messageRowID]!.attachments = [] - } + guard let attachment = try Attachment(row: row) else { + continue + } - guard let attachment = try Attachment(row: row) else { - return + messages[messageRowID]!.attachments!.append(attachment) } - - messages[messageRowID]!.attachments!.append(attachment) } } } extension Attachment { - init?(row: borrowing Row) throws { + init?(row: Row) throws { // (skipping `m.ROWID`) - guard let attachmentRowID = try row[1].optionalConverting(Int.self) else { + guard let attachmentRowID = row.optionalInt(at: 1) else { return nil } - let attachmentGUID = try GUID(row[2].expect(String.self)) - let fileName = try row[3].optionalConverting(String.self) - let transferName = try row[4].optionalConverting(String.self) - let isSticker = try row[5].looseBool() - let transferState = try Attachment.IMFileTransferState(rawValue: row[6].expectConverting(Int.self)) - let uti = try row[7].optionalConverting(String.self) + let attachmentGUID = GUID(row.requiredString(at: 2)) + let fileName = row.optionalString(at: 3) + let transferName = row.optionalString(at: 4) + let isSticker = row.looseBool(at: 5) + let transferState = Attachment.IMFileTransferState(rawValue: row.requiredInt(at: 6)) + let uti = row.optionalString(at: 7) self = Attachment( id: attachmentRowID, diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Chats.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Chats.swift index 620e0fc0..f59cbdb9 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Chats.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Chats.swift @@ -1,3 +1,4 @@ +import GRDB import IMessageCore import Logging @@ -6,19 +7,16 @@ private let log = Logger(label: "imdb.chats") public extension IMDatabase { // TODO: replace with overload that takes `GUID` func chat(withGUID chatGUID: String) throws -> Chat? { - let statement = try cachedStatement(forEscapedSQL: """ - SELECT ROWID, display_name, service_name - FROM chat - WHERE guid = ? - """) - - try statement.reset() - try statement.bind(chatGUID) - - let chats = try statement.mapRowsUntilDone { row in - let displayName = try row[1].optional(String.self)?.nonEmpty - let serviceName = try Chat.ServiceName(rawValue: row[2].expect(String.self)) - return try Chat(id: row[0].expect(Int.self), guid: GUID(chatGUID), displayName: displayName, serviceName: serviceName) + let chats = try read { db in + try Row.fetchAll(db, sql: """ + SELECT ROWID, display_name, service_name + FROM chat + WHERE guid = ? + """, arguments: [chatGUID]).map { row in + let displayName = row.optionalString(at: 1)?.nonEmpty + let serviceName = Chat.ServiceName(rawValue: row.requiredString(at: 2)) + return Chat(id: row.requiredInt(at: 0), guid: GUID(chatGUID), displayName: displayName, serviceName: serviceName) + } } if chats.count > 1 { @@ -32,41 +30,36 @@ public extension IMDatabase { } func chats() throws -> [Chat] { - let statement = try cachedStatement(forEscapedSQL: """ - SELECT ROWID, guid, display_name, service_name - FROM chat - """) - - try statement.reset() - - return try statement.mapRowsUntilDone { row -> Chat? in - let id = try row[0].expect(Int.self) - guard let guid = try row[1].optional(String.self) else { - log.error("chat \(id) has no GUID, very spooky. dropping it on the ground") - return nil + try read { db in + try Row.fetchAll(db, sql: """ + SELECT ROWID, guid, display_name, service_name + FROM chat + """).compactMap { row -> Chat? in + let id = row.requiredInt(at: 0) + guard let guid = row.optionalString(at: 1) else { + log.error("chat \(id) has no GUID, very spooky. dropping it on the ground") + return nil + } + let displayName = row.optionalString(at: 2)?.nonEmpty + let serviceName = Chat.ServiceName(rawValue: row.optionalString(at: 3) ?? "NONE") + return Chat(id: id, guid: GUID(guid), displayName: displayName, serviceName: serviceName) } - let displayName = try row[2].optional(String.self)?.nonEmpty - let serviceName = try Chat.ServiceName(rawValue: row[3].optional(String.self) ?? "NONE") - return Chat(id: id, guid: GUID(guid), displayName: displayName, serviceName: serviceName) - }.compactMap(\.self) + } } // this doesn't include the user themselves, just everyone else in the group chat, // UNLESS the user went out of their way to redundantly add themselves, which is possible when initially creating the chat func handles(inChatWithGUID chatGUID: String) throws -> [Handle] { - let statement = try cachedStatement(forEscapedSQL: """ - SELECT handle.ROWID, handle.id - FROM chat - INNER JOIN chat_handle_join ON chat_handle_join.chat_id = chat.ROWID - INNER JOIN handle ON handle.ROWID = chat_handle_join.handle_id - WHERE chat.guid = ? - """) - - try statement.reset() - try statement.bind(chatGUID) - - return try statement.mapRowsUntilDone { row in - try Handle(rowid: row[0].expect(Int.self), id: row[1].expect(String.self)) + try read { db in + try Row.fetchAll(db, sql: """ + SELECT handle.ROWID, handle.id + FROM chat + INNER JOIN chat_handle_join ON chat_handle_join.chat_id = chat.ROWID + INNER JOIN handle ON handle.ROWID = chat_handle_join.handle_id + WHERE chat.guid = ? + """, arguments: [chatGUID]).map { row in + Handle(rowid: row.requiredInt(at: 0), id: row.requiredString(at: 1)) + } } } } diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedMessages.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedMessages.swift index a86e822f..cfbcd0e0 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedMessages.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedMessages.swift @@ -1,5 +1,7 @@ +import Collections import Foundation -import SQLite +import IMessageCore +import GRDB private let messageJoins = """ LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID @@ -26,17 +28,15 @@ LEFT JOIN handle AS oh ON m.other_handle = oh.ROWID public extension IMDatabase { func lastMessageRowID() throws -> Int { - let statement = try cachedStatement(forEscapedSQL: "SELECT seq FROM sqlite_sequence WHERE name = 'message'").reset() - return try statement.compactMapRowsUntilDone { row in - try row[0].optionalConverting(Int.self) - }.first ?? 0 + try read { db in + try fetchOneCached(Int.self, db: db, sql: "SELECT seq FROM sqlite_sequence WHERE name = 'message'") ?? 0 + } } func maxMessageDateRead() throws -> Date { - let statement = try cachedStatement(forEscapedSQL: "SELECT MAX(date_read) FROM message").reset() - let nanoseconds = try statement.compactMapRowsUntilDone { row in - try row[0].optionalConverting(Int.self) - }.first ?? 0 + let nanoseconds = try read { db in + try fetchOneCached(Int.self, db: db, sql: "SELECT MAX(date_read) FROM message") ?? 0 + } guard nanoseconds > 0, nanoseconds < .max else { return Date(nanosecondsSinceReferenceDate: 0) @@ -45,40 +45,64 @@ public extension IMDatabase { return Date(nanosecondsSinceReferenceDate: nanoseconds) } + func messageUpdateCursorSnapshot() throws -> (lastRowID: Int, lastDateRead: Date, lastDateEdited: Date) { + let messageSchema = try schema().message + let dateEditedSelection = messageSchema.has(.dateEdited) + ? "COALESCE((SELECT MAX(\(MessageTable.Column.dateEdited.sqlName)) FROM \(MessageTable.sqlName)), 0)" + : "0" + let sql = """ + SELECT + COALESCE((SELECT \(SQLiteSequenceTable.Column.seq.sqlName) FROM \(SQLiteSequenceTable.sqlName) WHERE \(SQLiteSequenceTable.Column.name.sqlName) = '\(MessageTable.sqlName)'), 0), + COALESCE((SELECT MAX(\(MessageTable.Column.dateRead.sqlName)) FROM \(MessageTable.sqlName)), 0), + \(dateEditedSelection) + """ + + return try read { db in + try fetchAllRowsCached(db: db, sql: sql).map { row in + ( + lastRowID: row.optionalInt(at: 0) ?? 0, + lastDateRead: row.imCoreDate(at: 1) ?? Date(nanosecondsSinceReferenceDate: 0), + lastDateEdited: row.imCoreDate(at: 2) ?? Date(nanosecondsSinceReferenceDate: 0) + ) + }.first + } ?? ( + lastRowID: 0, + lastDateRead: Date(nanosecondsSinceReferenceDate: 0), + lastDateEdited: Date(nanosecondsSinceReferenceDate: 0) + ) + } + func sentMessageIDs(since rowID: Int) throws -> [(rowID: Int, guid: String)] { - let statement = try cachedStatement(forEscapedSQL: """ - SELECT ROWID, guid - FROM message - WHERE is_from_me = 1 AND ROWID > ? - """).reset() - try statement.bind(rowID) - return try statement.compactMapRowsUntilDone { row in - guard let rowID = try row[0].optionalConverting(Int.self), - let guid = try row[1].optionalConverting(String.self) else { - return nil + try read { db in + try fetchAllRowsCached(db: db, sql: """ + SELECT ROWID, guid + FROM message + WHERE is_from_me = 1 AND ROWID > ? + """, arguments: [rowID]).compactMap { row in + guard let rowID = row.optionalInt(at: 0), + let guid = row.optionalString(at: 1) else { + return nil + } + return (rowID, guid) } - return (rowID, guid) } } func threadIDForMessage(rowID: Int) throws -> String? { - let statement = try cachedStatement(forEscapedSQL: """ - SELECT t.guid - FROM message AS m - LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID - LEFT JOIN chat AS t ON cmj.chat_id = t.ROWID - WHERE m.ROWID = ? - """).reset() - try statement.bind(rowID) - return try statement.compactMapRowsUntilDone { row in - try row[0].optionalConverting(String.self) - }.first + try read { db in + try fetchOneCached(String.self, db: db, sql: """ + SELECT t.guid + FROM message AS m + LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID + LEFT JOIN chat AS t ON cmj.chat_id = t.ROWID + WHERE m.ROWID = ? + """, arguments: [rowID]) + } } func allThreadGUIDs() throws -> [String] { - let statement = try cachedStatement(forEscapedSQL: "SELECT guid FROM chat").reset() - return try statement.compactMapRowsUntilDone { row in - try row[0].optional(String.self) + try read { db in + try fetchAllRowsCached(db: db, sql: "SELECT guid FROM chat").compactMap { $0[0] as String? } } } @@ -88,13 +112,13 @@ public extension IMDatabase { direction: MappedPageDirection?, limit: Int = 20 ) throws -> [MappedMessageRow] { - let messageColumns = try tableColumns("message") + let messageSchema = try schema().message let withCursor = cursor.flatMap { Int($0) }.map { (cursor: $0, direction: direction ?? .before) } let comparisonOperator = withCursor.map { $0.direction == .after ? ">" : "<" } let order = withCursor?.direction == .after ? "ASC" : "DESC" - let dateExpression = comparisonOperator == ">" && messageColumns.contains("date_edited") - ? "MAX(m.date, COALESCE(m.date_edited, 0))" - : "cmj.message_date" + let dateExpression = comparisonOperator == ">" && messageSchema.has(.dateEdited) + ? "MAX(m.\(MessageTable.Column.date.sqlName), COALESCE(m.\(MessageTable.Column.dateEdited.sqlName), 0))" + : "cmj.\(ChatMessageJoinTable.Column.messageDate.sqlName)" // The historical query filtered by chat guid after starting from // `message ORDER BY date`. On large databases that can walk a huge @@ -106,7 +130,7 @@ public extension IMDatabase { var sql = """ SELECT - \(messageSelectionSQL(messageColumns: messageColumns)) + \(messageSelectionSQL(messageSchema: messageSchema)) FROM chat_message_join AS cmj \(messageJoinsFromChatMessageJoin) WHERE cmj.chat_id = ? @@ -116,57 +140,75 @@ public extension IMDatabase { } sql += "\nORDER BY cmj.message_date \(order), cmj.message_id \(order)\nLIMIT \(limit)" - let statement = try Statement.prepare(escapedSQL: sql, for: database) - if let withCursor { - try statement.bind(chatRowID, withCursor.cursor) - } else { - try statement.bind(chatRowID) + return try read { db in + if let withCursor { + return try MappedMessageRow.fetchAllMapped(db, sql: sql, arguments: sqlArguments([chatRowID, withCursor.cursor])) + } + return try MappedMessageRow.fetchAllMapped(db, sql: sql, arguments: [chatRowID]) } - - return try statement.mapRowsUntilDone(MappedMessageRow.self) } func mappedChatRowID(guid: String) throws -> Int? { - let statement = try cachedStatement(forEscapedSQL: "SELECT ROWID FROM chat WHERE guid = ?").reset() - try statement.bind(guid) - return try statement.compactMapRowsUntilDone { row in - try row[0].optionalConverting(Int.self) - }.first + try read { db in + try fetchOneCached(Int.self, db: db, sql: "SELECT ROWID FROM chat WHERE guid = ?", arguments: [guid]) + } } func mappedMessageRow(guid: String) throws -> MappedMessageRow? { - let messageColumns = try tableColumns("message") + try mappedMessageRows(guids: [guid]).first + } + + func mappedMessageRows(guids: [String]) throws -> [MappedMessageRow] { + guard !guids.isEmpty else { return [] } + let uniqueGUIDs = Array(OrderedSet(guids)) + + guard uniqueGUIDs.count <= maxMappedMessageRowsBatchSize else { + return try uniqueGUIDs + .chunks(ofCount: maxMappedMessageRowsBatchSize) + .flatMap { try mappedMessageRows(guids: Array($0)) } + } + + let messageSchema = try schema().message let sql = """ SELECT - \(messageSelectionSQL(messageColumns: messageColumns)) + \(messageSelectionSQL(messageSchema: messageSchema)) FROM message AS m \(messageJoins) - WHERE m.guid = ? + WHERE m.guid IN (\(placeholders(count: uniqueGUIDs.count))) """ - let statement = try Statement.prepare(escapedSQL: sql, for: database) - try statement.bind(guid) - return try statement.mapRowsUntilDone(MappedMessageRow.self).first + return try read { db in + try MappedMessageRow.fetchAllMapped(db, sql: sql, arguments: StatementArguments(uniqueGUIDs)) + } } func mappedMessageRows(rowIDs: [Int]) throws -> [MappedMessageRow] { guard !rowIDs.isEmpty else { return [] } - let messageColumns = try tableColumns("message") + let uniqueRowIDs = Array(OrderedSet(rowIDs)) + + guard uniqueRowIDs.count <= maxMappedMessageRowsBatchSize else { + return try uniqueRowIDs + .chunks(ofCount: maxMappedMessageRowsBatchSize) + .flatMap { try mappedMessageRows(rowIDs: Array($0)) } + .sorted { ($0.date ?? 0) > ($1.date ?? 0) } + } + + let messageSchema = try schema().message let sql = """ SELECT - \(messageSelectionSQL(messageColumns: messageColumns)) + \(messageSelectionSQL(messageSchema: messageSchema)) FROM message AS m \(messageJoins) - WHERE m.ROWID IN (\(placeholders(count: rowIDs.count))) + WHERE m.ROWID IN (\(placeholders(count: uniqueRowIDs.count))) ORDER BY m.date DESC """ - let statement = try Statement.prepare(escapedSQL: sql, for: database) - try statement.bind(rowIDs.map { $0 as any SQLiteBindable }) - return try statement.mapRowsUntilDone(MappedMessageRow.self) + return try read { db in + try MappedMessageRow.fetchAllMapped(db, sql: sql, arguments: StatementArguments(uniqueRowIDs)) + } } func mappedLatestMessageRows(chatRowIDs: [Int]) throws -> [String: MappedMessageRow] { guard !chatRowIDs.isEmpty else { return [:] } - let messageColumns = try tableColumns("message") + let messageSchema = try schema().message let sql = """ WITH requested_chat(rowid) AS ( VALUES \(rowValuePlaceholders(count: chatRowIDs.count)) @@ -184,16 +226,16 @@ public extension IMDatabase { FROM requested_chat ) SELECT - \(messageSelectionSQL(messageColumns: messageColumns)) + \(messageSelectionSQL(messageSchema: messageSchema)) FROM latest_join \(latestMessageJoins) ORDER BY m.date DESC """ - let statement = try Statement.prepare(escapedSQL: sql, for: database) - try statement.bind(chatRowIDs.map { $0 as any SQLiteBindable }) - return try statement.mapRowsUntilDone(MappedMessageRow.self).reduce(into: [:]) { result, messageRow in - guard let threadID = messageRow.threadID else { return } - result[threadID] = messageRow + return try read { db in + try MappedMessageRow.fetchAllMapped(db, sql: sql, arguments: StatementArguments(chatRowIDs)).reduce(into: [:]) { result, messageRow in + guard let threadID = messageRow.threadID else { return } + result[threadID] = messageRow + } } } @@ -206,35 +248,33 @@ public extension IMDatabase { LEFT JOIN attachment AS a ON a.ROWID = maj.attachment_id WHERE m.ROWID IN (\(placeholders(count: messageRowIDs.count))) """ - let statement = try Statement.prepare(escapedSQL: sql, for: database) - try statement.bind(messageRowIDs.map { $0 as any SQLiteBindable }) - return try statement.mapRowsUntilDone(MappedAttachmentRow.self) + return try read { db in + try MappedAttachmentRow.fetchAllMapped(db, sql: sql, arguments: StatementArguments(messageRowIDs)) + } } func attachmentFilename(guid: String) throws -> String? { - let statement = try cachedStatement(forEscapedSQL: "SELECT filename FROM attachment WHERE guid = ?").reset() - try statement.bind(guid) - return try statement.compactMapRowsUntilDone { row in - try row[0].optional(String.self) - }.first + try read { db in + try fetchOneCached(String.self, db: db, sql: "SELECT filename FROM attachment WHERE guid = ?", arguments: [guid]) + } } func attachmentFilename(messageRowID: Int) throws -> String? { - let statement = try cachedStatement(forEscapedSQL: """ - SELECT a.filename FROM message_attachment_join AS maj - INNER JOIN attachment AS a ON a.ROWID = maj.attachment_id - WHERE maj.message_id = ? - """).reset() - try statement.bind(messageRowID) - return try statement.compactMapRowsUntilDone { row in - try row[0].optional(String.self) - }.first + try read { db in + try fetchOneCached(String.self, db: db, sql: """ + SELECT a.filename FROM message_attachment_join AS maj + INNER JOIN attachment AS a ON a.ROWID = maj.attachment_id + WHERE maj.message_id = ? + """, arguments: [messageRowID]) + } } func mappedReactionRows(messageGUIDs: [String], chatRowIDs: [Int]) throws -> [MappedReactionMessageRow] { guard !messageGUIDs.isEmpty, !chatRowIDs.isEmpty else { return [] } - let messageColumns = try tableColumns("message") - let emojiColumn = messageColumns.contains("associated_message_emoji") ? "associated_message_emoji," : "" + let messageSchema = try schema().message + let emojiColumn = messageSchema.has(.associatedMessageEmoji) + ? "m.\(MessageTable.Column.associatedMessageEmoji.sqlName) AS \(MessageTable.Column.associatedMessageEmoji.sqlName)," + : "" let messageGUIDPlaceholders = messageGUIDs.map { _ in "?" }.joined(separator: ",") let chatRowIDPlaceholders = chatRowIDs.map { _ in "?" }.joined(separator: ",") let sql = """ @@ -244,12 +284,12 @@ public extension IMDatabase { LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID WHERE REPLACE(SUBSTR(associated_message_guid, INSTR(associated_message_guid, '/') + 1), 'bp:', '') IN (\(messageGUIDPlaceholders)) AND cmj.chat_id IN (\(chatRowIDPlaceholders)) + ORDER BY m.ROWID ASC """ - let statement = try Statement.prepare(escapedSQL: sql, for: database) - var bindings = messageGUIDs.map { $0 as any SQLiteBindable } - bindings.append(contentsOf: chatRowIDs.map { $0 as any SQLiteBindable }) - try statement.bind(bindings) - return try statement.mapRowsUntilDone(MappedReactionMessageRow.self) + return try read { db in + let bindings = messageGUIDs.map { $0 as Any } + chatRowIDs.map { $0 as Any } + return try MappedReactionMessageRow.fetchAllMapped(db, sql: sql, arguments: sqlArguments(bindings)) + } } func mappedReactionRows(messageGUIDs: [String], chatRowID: Int) throws -> [MappedReactionMessageRow] { @@ -264,9 +304,11 @@ public extension IMDatabase { } } -private func messageSelectionSQL(messageColumns: [String]) -> String { +private let maxMappedMessageRowsBatchSize = 500 + +private func messageSelectionSQL(messageSchema: TableSchema) -> String { var selections = ["m.ROWID AS ROWID"] - selections += messageColumns + selections += messageSchema.columns .filter { $0 != "ROWID" } .map { "m.\($0) AS \($0)" } selections += [ diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedShared.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedShared.swift index ad758d04..a3f30723 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedShared.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedShared.swift @@ -1,11 +1,10 @@ import Foundation -import SQLite +import GRDB extension Database { func tableColumns(_ tableName: String) throws -> [String] { - let statement = try Statement.prepare(escapedSQL: "PRAGMA table_info(\(tableName))", for: self) - return try statement.mapRowsUntilDone { row in - try row[1].expect(String.self) + try Row.fetchAll(self, SQLRequest(sql: "PRAGMA table_info(\(tableName))", cached: true)).map { row in + row[1] as String } } } @@ -15,8 +14,42 @@ extension IMDatabase { if let cached = tableColumnCache[tableName] { return cached } - let columns = try database.tableColumns(tableName) + let columns = try read { db in + try db.tableColumns(tableName) + } tableColumnCache[tableName] = columns return columns } } + +func sqlArguments(_ values: [Any]) -> StatementArguments { + guard let arguments = StatementArguments(values) else { + preconditionFailure("all SQL arguments must be database values") + } + return arguments +} + +func fetchOneCached( + _ type: T.Type, + db: Database, + sql: String, + arguments: StatementArguments = StatementArguments() +) throws -> T? { + try T.fetchOne(db, SQLRequest(sql: sql, arguments: arguments, cached: true)) +} + +func fetchAllRowsCached( + db: Database, + sql: String, + arguments: StatementArguments = StatementArguments() +) throws -> [Row] { + try Row.fetchAll(db, SQLRequest(sql: sql, arguments: arguments, cached: true)) +} + +func fetchCursorRowsCached( + db: Database, + sql: String, + arguments: StatementArguments = StatementArguments() +) throws -> RowCursor { + try Row.fetchCursor(db, SQLRequest(sql: sql, arguments: arguments, cached: true)) +} diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedThreads.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedThreads.swift index 0a4311df..95bb9363 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedThreads.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+MappedThreads.swift @@ -1,4 +1,4 @@ -import SQLite +import GRDB public let mappedThreadsLimit = 25 @@ -8,12 +8,12 @@ public extension IMDatabase { direction: MappedPageDirection?, limit: Int = mappedThreadsLimit ) throws -> [MappedChatRow] { - let chatColumns = try tableColumns("chat") + let chatSchema = try schema().chat let withCursor = cursor.flatMap { Int($0) }.map { (cursor: $0, direction: direction ?? .before) } let comparisonOperator = withCursor.map { $0.direction == .after ? ">" : "<" } var sql = """ SELECT - \(chatSelectionSQL(chatColumns: chatColumns)), + \(chatSelectionSQL(chatSchema: chatSchema)), (SELECT MAX(message_date) FROM chat_message_join WHERE chat_id = chat.ROWID) AS msgDate FROM chat """ @@ -22,25 +22,26 @@ public extension IMDatabase { } sql += "\nORDER BY msgDate DESC\nLIMIT \(limit)" - let statement = try Statement.prepare(escapedSQL: sql, for: database) - if let withCursor { - try statement.bind(withCursor.cursor) + return try read { db in + if let withCursor { + return try MappedChatRow.fetchAllMapped(db, sql: sql, arguments: [withCursor.cursor]) + } + return try MappedChatRow.fetchAllMapped(db, sql: sql) } - return try statement.mapRowsUntilDone(MappedChatRow.self) } func mappedThreadRow(guid: String) throws -> MappedChatRow? { - let chatColumns = try tableColumns("chat") + let chatSchema = try schema().chat let sql = """ SELECT - \(chatSelectionSQL(chatColumns: chatColumns)), + \(chatSelectionSQL(chatSchema: chatSchema)), (SELECT MAX(message_date) FROM chat_message_join WHERE chat_id = chat.ROWID) AS msgDate FROM chat WHERE chat.guid = ? """ - let statement = try Statement.prepare(escapedSQL: sql, for: database) - try statement.bind(guid) - return try statement.mapRowsUntilDone(MappedChatRow.self).first + return try read { db in + try MappedChatRow.fetchAllMapped(db, sql: sql, arguments: [guid]).first + } } func mappedThreadParticipantRows(chatRowIDs: [Int]) throws -> [Int: [MappedHandleRow]] { @@ -51,10 +52,10 @@ public extension IMDatabase { LEFT JOIN chat_handle_join AS chj ON chj.handle_id = handle.ROWID WHERE chat_id IN (\(chatRowIDs.map { _ in "?" }.joined(separator: ", "))) """ - let statement = try Statement.prepare(escapedSQL: sql, for: database) - try statement.bind(chatRowIDs.map { $0 as any SQLiteBindable }) - return try statement.mapRowsUntilDone(MappedHandleRow.self).reduce(into: [:]) { result, row in - result[row.chatID ?? -1, default: []].append(row) + return try read { db in + try MappedHandleRow.fetchAllMapped(db, sql: sql, arguments: StatementArguments(chatRowIDs)).reduce(into: [:]) { result, row in + result[row.chatID ?? -1, default: []].append(row) + } } } @@ -75,19 +76,19 @@ public extension IMDatabase { GROUP BY cm.chat_id """ - let statement = try Statement.prepare(escapedSQL: sql, for: database) - try statement.bind(chatRowIDs.map { $0 as any SQLiteBindable }) - return try statement.mapRowsUntilDone { row in - (try row[0].expectConverting(Int.self), try row[1].expectConverting(Int.self)) - }.reduce(into: [:]) { result, pair in - result[pair.0] = pair.1 + return try read { db in + try fetchAllRowsCached(db: db, sql: sql, arguments: StatementArguments(chatRowIDs)).map { row in + (row.requiredInt(at: 0), row.requiredInt(at: 1)) + }.reduce(into: [:]) { result, pair in + result[pair.0] = pair.1 + } } } } -private func chatSelectionSQL(chatColumns: [String]) -> String { +private func chatSelectionSQL(chatSchema: TableSchema) -> String { var selections = ["chat.ROWID AS ROWID"] - selections += chatColumns + selections += chatSchema.columns .filter { $0 != "ROWID" } .map { "chat.\($0) AS \($0)" } return selections.joined(separator: ",\n") diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Messages.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Messages.swift index cb32c0f1..dc737be3 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Messages.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Messages.swift @@ -1,6 +1,6 @@ import Collections import Foundation -import SQLite +import GRDB public enum DateOrdering { case newestFirst @@ -41,19 +41,20 @@ public extension IMDatabase { with guid: GUID, withAttachments includeAttachments: Bool = true, ) throws -> (message: Message, chatGUID: GUID)? { - let statement = try cachedStatement(forEscapedSQL: """ - \(messagesQuerySharedPrelude) - WHERE m.guid = ? - """).reset() - try statement.bind(guid) + let result = try read { db in + try Row.fetchAll(db, sql: """ + \(messagesQuerySharedPrelude) + WHERE m.guid = ? + """, arguments: [guid]).compactMap { row -> (Message, GUID)? in + guard let chatGUID = row.optionalString(at: 0) else { + // drop orphaned (not within a chat) messages + return nil + } + return try (Message(row: row), GUID(chatGUID)) + }.first + } - guard let (initialMessage, chatGUID) = try statement.compactMapRowsUntilDone({ row -> (Message, GUID)? in - guard let chatGUID = try row[0].optionalConverting(String.self) else { - // drop orphaned (not within a chat) messages - return nil - } - return try (Message(row: row), GUID(chatGUID)) - }).first else { + guard let (initialMessage, chatGUID) = result else { return nil } @@ -72,19 +73,19 @@ public extension IMDatabase { limit: Int = 50, withAttachments includeAttachments: Bool = true, ) throws -> some Collection { - let statement = try cachedStatement(forEscapedSQL: """ - \(messagesQuerySharedPrelude) - WHERE c.guid = ? - \(filter.map { "AND m.\($0.sqlFragment)" } ?? "") - ORDER BY m.date \(order.sqlKeyword) - LIMIT ? - """).reset() - try statement.bind(chatGUID, limit) - var messages = OrderedDictionary() - try statement.stepUntilDone { row in - let message = try Message(row: row) - messages[message.id] = message + try read { db in + let rows = try Row.fetchAll(db, sql: """ + \(messagesQuerySharedPrelude) + WHERE c.guid = ? + \(filter.map { "AND m.\($0.sqlFragment)" } ?? "") + ORDER BY m.date \(order.sqlKeyword) + LIMIT ? + """, arguments: sqlArguments([chatGUID, limit])) + for row in rows { + let message = try Message(row: row) + messages[message.id] = message + } } if includeAttachments { @@ -96,24 +97,24 @@ public extension IMDatabase { } private extension Message { - init(row: borrowing Row) throws { + init(row: Row) throws { // (skipping `c.guid`) self = try Message( - id: row[1].expect(Int.self), - guid: GUID(row[2].expect(String.self)), - balloonBundleID: try row[3].optional(String.self), - threadOriginatorGUID: try row[4].optional(String.self).map(GUID.init(stringLiteral:)), - text: row[5].optional(String.self).map { + id: row.requiredInt(at: 1), + guid: GUID(row.requiredString(at: 2)), + balloonBundleID: row.optionalString(at: 3), + threadOriginatorGUID: row.optionalString(at: 4).map(GUID.init(stringLiteral:)), + text: row.optionalString(at: 5).map { Sensitive(.messageText, hiding: $0) }, - attributedBody: row[6].optional(Data.self).flatMap { + attributedBody: row.optionalData(at: 6).flatMap { try Sensitive(.messageAttributedBody, hiding: AttributedBodyDecoder.attributedString(from: $0)) }, - isFromMe: row[7].looseBool(), - isSent: row[8].looseBool(), - date: row[9].imCoreDate(), - dateRead: row[10].imCoreDate(), - summaryInfo: row[11].optionalConverting(Data.self).map(Message.SummaryInfo.init(blob:)), + isFromMe: row.looseBool(at: 7), + isSent: row.looseBool(at: 8), + date: row.imCoreDate(at: 9), + dateRead: row.imCoreDate(at: 10), + summaryInfo: row.optionalData(at: 11).map(Message.SummaryInfo.init(blob:)), ) } } diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Search.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Search.swift index b4d694c6..214bd76f 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Search.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Search.swift @@ -1,5 +1,5 @@ import Foundation -import SQLite +import GRDB public extension IMDatabase { /// Searches messages by text content, properly decoding attributedBody. @@ -61,43 +61,38 @@ public extension IMDatabase { // Fetch more than limit to account for filtering - we'll filter in Swift after decoding let fetchLimit = limit * 20 - let statement = try cachedStatement(forEscapedSQL: sql).reset() - - // Bind parameters in order - if let chatGUID { - try statement.bind(chatGUID, fetchLimit) - } else { - try statement.bind(fetchLimit) - } - var matchingRowIDs: [Int] = [] + let arguments = chatGUID.map { sqlArguments([$0, fetchLimit]) } ?? StatementArguments([fetchLimit]) - try statement.stepUntilDone { row in - // Stop once we have enough results - guard matchingRowIDs.count < limit else { return } + try read { db in + let cursor = try fetchCursorRowsCached(db: db, sql: sql, arguments: arguments) + while let row = try cursor.next() { + // Stop once we have enough results + guard matchingRowIDs.count < limit else { break } - let rowID = try row[0].expect(Int.self) - let plainText = try row[1].optional(String.self) - let attributedBodyData = try row[2].optional(Data.self) + let rowID = row.requiredInt(at: 0) + let plainText = row.optionalString(at: 1) + let attributedBodyData = row.optionalData(at: 2) - // Try to get text from attributedBody first (more complete), fall back to text column - var messageText: String? + // Try to get text from attributedBody first (more complete), fall back to text column + var messageText: String? - if let data = attributedBodyData { - messageText = try? AttributedBodyDecoder.plainText(from: data) - } + if let data = attributedBodyData { + messageText = try? AttributedBodyDecoder.plainText(from: data) + } - // Fall back to plain text column - if messageText == nil || messageText?.isEmpty == true { - messageText = plainText - } + // Fall back to plain text column + if messageText == nil || messageText?.isEmpty == true { + messageText = plainText + } - // Check if the decoded text actually contains the search query (case-insensitive) - guard let text = messageText, text.lowercased().contains(queryLower) else { - return - } + // Check if the decoded text actually contains the search query (case-insensitive) + guard let text = messageText, text.lowercased().contains(queryLower) else { + continue + } - matchingRowIDs.append(rowID) + matchingRowIDs.append(rowID) + } } return matchingRowIDs diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Unreads.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Unreads.swift index bb301a99..b6faecd1 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Unreads.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Unreads.swift @@ -1,14 +1,10 @@ import Foundation -import Logging -import SQLite +import GRDB import IMessageCore -private let log = Logger(label: "imdb.unreads") - // TODO(skip): optimize; query takes ~70ms (!) let unreadStatesQuery = """ SELECT - c.ROWID AS chat_id, c.guid AS chat_guid, COUNT( CASE @@ -55,23 +51,19 @@ public extension IMDatabase { return (unreadCounts[chat.id] ?? 0) == 0 } - func chatStates() throws -> [ChatRef: ChatState] { - let statement = try cachedStatement(forEscapedSQL: unreadStatesQuery) - try statement.reset() - - var chatStates: [ChatRef: ChatState] = [:] + func chatStates() throws -> [String: ChatState] { + var chatStates: [String: ChatState] = [:] - try statement.stepUntilDone { row in - guard let chatRef: ChatRef = try ChatRef(rowID: row[0].optional(Int.self), guid: row[1].optional(String.self)) else { - log.warning("while querying unread states: some chat had neither a rowid nor a guid. can't really do much with this") - return - } + try read { db in + for row in try fetchAllRowsCached(db: db, sql: unreadStatesQuery) { + let chatGUID = row.requiredString(at: 0) - let lastReadMessageTimestamp = try Date(nanosecondsSinceReferenceDate: row[3].expect(Int.self)) + let lastReadMessageTimestamp = Date(nanosecondsSinceReferenceDate: row.requiredInt(at: 2)) - let unreadCount: Int = try row[2].expect(Int.self) + let unreadCount = row.requiredInt(at: 1) - chatStates[chatRef] = ChatState(unreadCount: unreadCount, lastReadMessageTimestamp: lastReadMessageTimestamp) + chatStates[chatGUID] = ChatState(unreadCount: unreadCount, lastReadMessageTimestamp: lastReadMessageTimestamp) + } } return chatStates diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Updates.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Updates.swift index 972dc7c0..62829ddd 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Updates.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase+Updates.swift @@ -1,108 +1,128 @@ import Foundation +import GRDB import Logging private let log = Logger(label: "imdb.updates") -let updatedChatsSinceQuery = """ -SELECT - m.ROWID, - m.date_read, - m.date_edited, - c.ROWID, - c.guid -FROM - message m -LEFT JOIN chat_message_join cmj ON cmj.message_id = m.ROWID -LEFT JOIN chat c ON cmj.chat_id = c.ROWID -WHERE - m.ROWID > ? OR m.date_read > ? OR m.date_edited > ? -GROUP BY - c.guid -ORDER BY - date DESC -""" - -public struct UpdatedChatsQueryResult { - public var updatedChats: [ChatRef] - /// This maximum is local to the set of updated chats. - public var latestMessageRowID: Int? - /// This maximum is local to the set of updated chats. - public var latestMessageDateRead: Date? - public var latestDateEdited: Date? +package struct UpdatedMessageChange { + package var rowID: Int + package var chatGUID: String + package var isNew: Bool + package var wasRead: Bool + package var wasEdited: Bool } -public extension IMDatabase { - func chats(withMessagesNewerThanRowID lastRowID: Int, orReadSince lastDateRead: Date, orEditedSince lastDateEdited: Date) throws -> UpdatedChatsQueryResult { - let statement = try cachedStatement(forEscapedSQL: updatedChatsSinceQuery) - - try statement.reset() - try statement.bind(lastRowID, lastDateRead.nanosecondsSinceReferenceDate, lastDateEdited.nanosecondsSinceReferenceDate) +package struct UpdatedMessagesQueryResult { + package var updatedMessages: [UpdatedMessageChange] + /// This maximum is local to the set of newly inserted message rows. + package var latestMessageRowID: Int? + /// This maximum is local to the set of read updates. + package var latestMessageDateRead: Date? + /// This maximum is local to the set of edit updates. + package var latestDateEdited: Date? +} +extension IMDatabase { + package + func messages(newerThanRowID lastRowID: Int, orReadSince lastDateRead: Date, orEditedSince lastDateEdited: Date) throws -> UpdatedMessagesQueryResult { + let messageSchema = try schema().message + let dateEditedExpression = messageSchema.has(.dateEdited) + ? "m.\(MessageTable.Column.dateEdited.sqlName)" + : "0" var newestMessageRowID: Int? var latestMessageDateRead: Date? var latestDateEdited: Date? var timesWarnedAboutOrphanedMessage = 0 - let updatedChats: [ChatRef] = try statement.compactMapRowsUntilDone { row in - let messageRowID = try row[0].expect(Int.self) - newestMessageRowID = max(messageRowID, newestMessageRowID ?? 0) + let updatedMessages: [UpdatedMessageChange] = try read { db in + try Row.fetchAll( + db, + sql: updatedMessagesSinceQuery(dateEditedExpression: dateEditedExpression), + arguments: StatementArguments([ + lastRowID, + lastDateRead.nanosecondsSinceReferenceDate, + lastDateEdited.nanosecondsSinceReferenceDate, + ]) + ).compactMap { row in + let messageRowID = row.requiredInt(at: 0) + let isNew = messageRowID > lastRowID + if isNew { + newestMessageRowID = max(messageRowID, newestMessageRowID ?? 0) + } - dateRead: do { - // IMCore typically uses `0` to represent absence, but fall back - // to `0` explicitly just in case. - let nanoseconds = try row[1].optional(Int.self) ?? 0 + var wasRead = false + var wasEdited = false - // If the message hasn't been read yet or has a bogus read date, - // then don't update the "latest read date" at all. I'm not sure - // what causes bogus read dates, but if you let it leak into the - // rest of the program then it can cause an integer overflow - // crash. - guard nanoseconds > 0, nanoseconds < .max else { - break dateRead + if let dateRead = row.imCoreDate(at: 1) { + wasRead = dateRead > lastDateRead + if wasRead { + latestMessageDateRead = if let latestMessageDateRead { + max(dateRead, latestMessageDateRead) + } else { + dateRead + } + } } - let dateRead = Date(nanosecondsSinceReferenceDate: nanoseconds) - latestMessageDateRead = if let latestMessageDateRead, dateRead < .distantFuture { - max(dateRead, latestMessageDateRead) - } else { - dateRead + if let dateEdited = row.imCoreDate(at: 2) { + wasEdited = dateEdited > lastDateEdited + if wasEdited { + latestDateEdited = if let latestDateEdited { + max(dateEdited, latestDateEdited) + } else { + dateEdited + } + } } - } - dateEdited: do { - let nanoseconds = try row[2].optional(Int.self) ?? 0 - guard nanoseconds > 0, nanoseconds < .max else { break dateEdited } - let dateEdited = Date(nanosecondsSinceReferenceDate: nanoseconds) - latestDateEdited = if let latestDateEdited, dateEdited < .distantFuture { - max(dateEdited, latestDateEdited) - } else { - dateEdited + guard let guid = row.optionalString(at: 3) else { + // For whatever reason it's possible for messages to not be + // joinable with chats. Right now I have one of these for a SMS + // TOTP verification code, which might've been automatically + // deleted in a weird way due to the autofill feature. + // + // In case there are tons of orphaned messages, don't spam the + // logs with this message. + if timesWarnedAboutOrphanedMessage < 10 { + log.error("couldn't join message \(messageRowID) to chat, dropping") + timesWarnedAboutOrphanedMessage += 1 + } + return nil } - } - guard let rowID = try row[3].optional(Int.self), let guid = try row[4].optional(String.self) else { - // For whatever reason it's possible for messages to not be - // joinable with chats. Right now I have one of these for a SMS - // TOTP verification code, which might've been automatically - // deleted in a weird way due to the autofill feature. - // - // In case there are tons of orphaned messages, don't spam the - // logs with this message. - if timesWarnedAboutOrphanedMessage < 10 { - log.error("couldn't join message \(messageRowID) to chat, dropping") - timesWarnedAboutOrphanedMessage += 1 - } - return nil + return UpdatedMessageChange( + rowID: messageRowID, + chatGUID: guid, + isNew: isNew, + wasRead: wasRead, + wasEdited: wasEdited + ) } - - return ChatRef(rowID: rowID, guid: guid) } - return UpdatedChatsQueryResult( - updatedChats: updatedChats, + return UpdatedMessagesQueryResult( + updatedMessages: updatedMessages, latestMessageRowID: newestMessageRowID, latestMessageDateRead: latestMessageDateRead, latestDateEdited: latestDateEdited ) } } + +private func updatedMessagesSinceQuery(dateEditedExpression: String) -> String { + """ + SELECT + m.\(MessageTable.Column.rowID.sqlName), + m.\(MessageTable.Column.dateRead.sqlName), + \(dateEditedExpression) AS \(MessageTable.Column.dateEdited.sqlName), + c.\(ChatTable.Column.guid.sqlName) + FROM + \(MessageTable.sqlName) m + LEFT JOIN \(ChatMessageJoinTable.sqlName) cmj ON cmj.\(ChatMessageJoinTable.Column.messageID.sqlName) = m.\(MessageTable.Column.rowID.sqlName) + LEFT JOIN \(ChatTable.sqlName) c ON cmj.\(ChatMessageJoinTable.Column.chatID.sqlName) = c.\(ChatTable.Column.rowID.sqlName) + WHERE + m.\(MessageTable.Column.rowID.sqlName) > ? OR m.\(MessageTable.Column.dateRead.sqlName) > ? OR \(dateEditedExpression) > ? + ORDER BY + m.\(MessageTable.Column.rowID.sqlName) ASC + """ +} diff --git a/src/IMessage/Sources/IMDatabase/Database/IMDatabase.swift b/src/IMessage/Sources/IMDatabase/Database/IMDatabase.swift index 75f87fab..05b8ceda 100644 --- a/src/IMessage/Sources/IMDatabase/Database/IMDatabase.swift +++ b/src/IMessage/Sources/IMDatabase/Database/IMDatabase.swift @@ -1,7 +1,7 @@ import AsyncAlgorithms import Foundation +import GRDB import Logging -import SQLite import IMessageCore private func chatDatabaseFile(in messagesDataURL: URL) -> URL { @@ -14,9 +14,9 @@ private func chatDatabaseWalFile(in messagesDataURL: URL) -> URL { private let log = Logger(label: "imdb") -private let messageIndexes = [ - ("message_idx_date_read", "date_read"), - ("message_idx_date_edited", "date_edited"), +private let messageIndexes: [(name: String, column: MessageTable.Column)] = [ + ("message_idx_date_read", .dateRead), + ("message_idx_date_edited", .dateEdited), ] public final class IMDatabase { @@ -40,10 +40,10 @@ public final class IMDatabase { public var noisy = false - var database: Database + var database: DatabaseQueue - private var statementCache = [String: Statement]() var tableColumnCache = [String: [String]]() + var schemaCache: IMDatabaseSchema? public init(messagesDataBaseURL: URL? = nil, createIndexes: Bool = false) throws { let messagesDataDirectory = messagesDataBaseURL ?? URL(fileURLWithPath: "\(NSHomeDirectory())/Library/Messages/") @@ -57,17 +57,13 @@ public final class IMDatabase { try Self.createIndexesIfNecessary(in: messagesDataDirectory) } - self.database = try Database(connecting: chatDatabaseFile(in: messagesDataDirectory).path, flags: .readOnly) + var configuration = Configuration() + configuration.readonly = true + self.database = try DatabaseQueue(path: chatDatabaseFile(in: messagesDataDirectory).path, configuration: configuration) } - func cachedStatement(forEscapedSQL sql: String) throws -> Statement { - if let cached = statementCache[sql] { - return cached - } - - let statement = try Statement.prepare(escapedSQL: sql, for: database, flags: .persistent) - statementCache[sql] = statement - return statement + func read(_ value: (Database) throws -> T) throws -> T { + try database.read(value) } deinit { @@ -81,11 +77,15 @@ public final class IMDatabase { private extension IMDatabase { static func createIndexesIfNecessary(in messagesDataDirectory: URL) throws { - let database = try Database(connecting: chatDatabaseFile(in: messagesDataDirectory).path, flags: .readWrite) - let messageColumns = try database.tableColumns("message") + let database = try DatabaseQueue(path: chatDatabaseFile(in: messagesDataDirectory).path) + let messageSchema = try database.read { db in + try TableSchema(columns: db.tableColumns(MessageTable.sqlName)) + } - for (indexName, columnName) in messageIndexes where messageColumns.contains(columnName) { - try database.execute(sqlWithoutEscaping: "CREATE INDEX IF NOT EXISTS \(indexName) ON message (\(columnName))") + try database.write { db in + for (indexName, column) in messageIndexes where messageSchema.has(column) { + try db.execute(sql: "CREATE INDEX IF NOT EXISTS \(indexName) ON \(MessageTable.sqlName) (\(column.sqlName))") + } } } } diff --git a/src/IMessage/Sources/IMDatabase/Models/GUID.swift b/src/IMessage/Sources/IMDatabase/Models/GUID.swift index 70190fce..10d5f8e3 100644 --- a/src/IMessage/Sources/IMDatabase/Models/GUID.swift +++ b/src/IMessage/Sources/IMDatabase/Models/GUID.swift @@ -1,4 +1,4 @@ -import SQLite +import GRDB public struct GUID: Sendable { var guts: String @@ -18,9 +18,16 @@ extension GUID: ExpressibleByStringLiteral { } } -extension GUID: SQLiteBindable { - public func unsafeBind(toPreparedStatement handle: OpaquePointer, at parameterIndex: Int32) throws { - try guts.unsafeBind(toPreparedStatement: handle, at: parameterIndex) +extension GUID: DatabaseValueConvertible { + public var databaseValue: DatabaseValue { + guts.databaseValue + } + + public static func fromDatabaseValue(_ dbValue: DatabaseValue) -> GUID? { + guard let string = String.fromDatabaseValue(dbValue) else { + return nil + } + return GUID(string) } } diff --git a/src/IMessage/Sources/IMDatabase/Models/MappedDatabaseRows+DictionaryBridges.swift b/src/IMessage/Sources/IMDatabase/Models/MappedDatabaseRows+DictionaryBridges.swift index 986b3388..dbd6e79b 100644 --- a/src/IMessage/Sources/IMDatabase/Models/MappedDatabaseRows+DictionaryBridges.swift +++ b/src/IMessage/Sources/IMDatabase/Models/MappedDatabaseRows+DictionaryBridges.swift @@ -28,6 +28,7 @@ public extension MappedMessageRow { payloadData = object.data("payload_data") expressiveSendStyleID = object.string("expressive_send_style_id") messageSummaryInfo = object.data("message_summary_info") + replyToGUID = object.string("reply_to_guid") threadOriginatorGUID = object.string("thread_originator_guid") threadOriginatorPart = object.string("thread_originator_part") dateRetracted = object.int("date_retracted") @@ -69,6 +70,7 @@ public extension MappedMessageRow { "payload_data": payloadData, "expressive_send_style_id": expressiveSendStyleID, "message_summary_info": messageSummaryInfo, + "reply_to_guid": replyToGUID, "thread_originator_guid": threadOriginatorGUID, "thread_originator_part": threadOriginatorPart, "date_retracted": dateRetracted, diff --git a/src/IMessage/Sources/IMDatabase/Models/MappedDatabaseRows.swift b/src/IMessage/Sources/IMDatabase/Models/MappedDatabaseRows.swift index d333af85..db1a9a42 100644 --- a/src/IMessage/Sources/IMDatabase/Models/MappedDatabaseRows.swift +++ b/src/IMessage/Sources/IMDatabase/Models/MappedDatabaseRows.swift @@ -1,5 +1,5 @@ import Foundation -import SQLite +import GRDB /// Runtime counterparts to the historical TypeScript row shapes. Database /// queries decode into these structs directly. Legacy fixture/original-payload @@ -43,8 +43,34 @@ import SQLite /// - `is_recovered`: Added in Ventura. /// - `is_deleting_incoming_messages`, `is_pending_review`: Observed in Tahoe. -public protocol MappedDatabaseRow { - init(row: borrowing Row, columns: MappedRowColumnIndexes) throws +public protocol MappedDatabaseRow: FetchableRecord { + init(row: Row, columns: MappedRowColumnIndexes) throws +} + +public extension MappedDatabaseRow { + init(row: Row) throws { + try self.init(row: row, columns: MappedRowColumnIndexes(Array(row.columnNames))) + } + + static func fetchAllMapped( + _ db: Database, + sql: String, + arguments: StatementArguments = StatementArguments() + ) throws -> [Self] { + let request = SQLRequest(sql: sql, arguments: arguments, cached: true) + let cursor = try Row.fetchCursor(db, request) + var rows: [Self] = [] + var columns: MappedRowColumnIndexes? + + while let row = try cursor.next() { + if columns == nil { + columns = MappedRowColumnIndexes(Array(row.columnNames)) + } + rows.append(try Self(row: row, columns: columns!)) + } + + return rows + } } public struct MappedRowColumnIndexes { @@ -54,22 +80,11 @@ public struct MappedRowColumnIndexes { indexesByName = Dictionary(uniqueKeysWithValues: names.enumerated().map { ($0.element, $0.offset) }) } - public init(statement: Statement) { - self.init(statement.columnNames) - } - func index(for name: String) -> Int? { indexesByName[name] } } -public extension Statement { - func mapRowsUntilDone(_: T.Type) throws -> [T] { - let columns = MappedRowColumnIndexes(statement: self) - return try mapRowsUntilDone { try T(row: $0, columns: columns) } - } -} - public enum MappedDatabaseRowError: Error, CustomStringConvertible { case missingRequiredColumn(row: String, column: String) @@ -119,6 +134,9 @@ public struct MappedMessageRow: MappedDatabaseRow { public let payloadData: Data? public let expressiveSendStyleID: String? public let messageSummaryInfo: Data? + /// GUID of a related message. iMessage uses this for reaction removal rows + /// to point back at the hidden reaction-add message row. + public let replyToGUID: String? public let threadOriginatorGUID: String? public let threadOriginatorPart: String? /// Added in Ventura. Apple nanosecond timestamp. Stringify at JSON/API @@ -139,7 +157,7 @@ public struct MappedMessageRow: MappedDatabaseRow { public let participantID: String? public let otherID: String? - public init(row: borrowing Row, columns: MappedRowColumnIndexes) throws { + public init(row: Row, columns: MappedRowColumnIndexes) throws { rowID = try row.requiredInt("ROWID", columns: columns, row: Self.self) guid = try row.requiredString("guid", columns: columns, row: Self.self) text = try row.string("text", columns: columns) @@ -166,6 +184,7 @@ public struct MappedMessageRow: MappedDatabaseRow { payloadData = try row.data("payload_data", columns: columns) expressiveSendStyleID = try row.string("expressive_send_style_id", columns: columns) messageSummaryInfo = try row.data("message_summary_info", columns: columns) + replyToGUID = try row.string("reply_to_guid", columns: columns) threadOriginatorGUID = try row.string("thread_originator_guid", columns: columns) threadOriginatorPart = try row.string("thread_originator_part", columns: columns) dateRetracted = try row.int("date_retracted", columns: columns) @@ -199,7 +218,7 @@ public struct MappedChatRow: MappedDatabaseRow { // the `chat` table; they are computed SQL aliases. public let msgDate: Int? - public init(row: borrowing Row, columns: MappedRowColumnIndexes) throws { + public init(row: Row, columns: MappedRowColumnIndexes) throws { rowID = try row.requiredInt("ROWID", columns: columns, row: Self.self) guid = try row.requiredString("guid", columns: columns, row: Self.self) state = try row.int("state", columns: columns) ?? 0 @@ -258,7 +277,7 @@ public struct MappedAttachmentRow: MappedDatabaseRow { self.size = size } - public init(row: borrowing Row, columns: MappedRowColumnIndexes) throws { + public init(row: Row, columns: MappedRowColumnIndexes) throws { try self.init( msgRowID: row.requiredInt("msgRowID", columns: columns, row: Self.self), filename: row.string("filename", columns: columns), @@ -290,7 +309,7 @@ public struct MappedHandleRow: MappedDatabaseRow { self.uncanonicalizedID = uncanonicalizedID } - public init(row: borrowing Row, columns: MappedRowColumnIndexes) throws { + public init(row: Row, columns: MappedRowColumnIndexes) throws { chatID = try row.int("chat_id", columns: columns) participantID = try row.string("participantID", columns: columns) uncanonicalizedID = try row.string("uncanonicalized_id", columns: columns) @@ -310,7 +329,7 @@ public struct MappedReactionMessageRow: MappedDatabaseRow { // for the sender associated with the reaction message. public let participantID: String? - public init(row: borrowing Row, columns: MappedRowColumnIndexes) throws { + public init(row: Row, columns: MappedRowColumnIndexes) throws { rowID = try row.requiredInt("ROWID", columns: columns, row: Self.self) isFromMe = try row.requiredInt("is_from_me", columns: columns, row: Self.self) handleID = try row.int("handle_id", columns: columns) @@ -322,7 +341,7 @@ public struct MappedReactionMessageRow: MappedDatabaseRow { } private extension Row { - borrowing func requiredString( + func requiredString( _ key: String, columns: MappedRowColumnIndexes, row: RowType.Type @@ -333,7 +352,7 @@ private extension Row { return value } - borrowing func requiredInt( + func requiredInt( _ key: String, columns: MappedRowColumnIndexes, row: RowType.Type @@ -344,24 +363,24 @@ private extension Row { return value } - borrowing func string(_ key: String, columns: MappedRowColumnIndexes) throws -> String? { + func string(_ key: String, columns: MappedRowColumnIndexes) throws -> String? { guard let index = columns.index(for: key) else { return nil } - return try self[index].optionalConverting(String.self) + return self[index] as String? } - borrowing func int(_ key: String, columns: MappedRowColumnIndexes) throws -> Int? { + func int(_ key: String, columns: MappedRowColumnIndexes) throws -> Int? { guard let index = columns.index(for: key) else { return nil } - return try self[index].optionalConverting(Int.self) + return self[index] as Int? } - borrowing func data(_ key: String, columns: MappedRowColumnIndexes) throws -> Data? { + func data(_ key: String, columns: MappedRowColumnIndexes) throws -> Data? { guard let index = columns.index(for: key) else { return nil } - return try self[index].optionalConverting(Data.self) + return self[index] as Data? } } diff --git a/src/IMessage/Sources/IMDatabase/Schema/IMDatabaseSchema.swift b/src/IMessage/Sources/IMDatabase/Schema/IMDatabaseSchema.swift new file mode 100644 index 00000000..d17ea410 --- /dev/null +++ b/src/IMessage/Sources/IMDatabase/Schema/IMDatabaseSchema.swift @@ -0,0 +1,66 @@ +import GRDB + +protocol IMDatabaseColumn: CaseIterable, ColumnExpression, Hashable, RawRepresentable where RawValue == String {} + +extension IMDatabaseColumn { + var sqlName: String { name } +} + +protocol IMDatabaseTable: TableRecord { + associatedtype Column: IMDatabaseColumn +} + +extension IMDatabaseTable { + static var sqlName: String { databaseTableName } +} + +struct TableSchema { + let columns: [String] + + private let columnNames: Set + + init(columns: [String]) { + self.columns = columns + columnNames = Set(columns) + } + + func has(_ column: Table.Column) -> Bool { + columnNames.contains(column.sqlName) + } +} + +struct IMDatabaseSchema { + let sqliteSequence: TableSchema + let message: TableSchema + let chat: TableSchema + let handle: TableSchema + let attachment: TableSchema + let chatMessageJoin: TableSchema + let chatHandleJoin: TableSchema + let messageAttachmentJoin: TableSchema + + init(columnsFor: (String) throws -> [String]) throws { + sqliteSequence = try TableSchema(columns: columnsFor(SQLiteSequenceTable.sqlName)) + message = try TableSchema(columns: columnsFor(MessageTable.sqlName)) + chat = try TableSchema(columns: columnsFor(ChatTable.sqlName)) + handle = try TableSchema(columns: columnsFor(HandleTable.sqlName)) + attachment = try TableSchema(columns: columnsFor(AttachmentTable.sqlName)) + chatMessageJoin = try TableSchema(columns: columnsFor(ChatMessageJoinTable.sqlName)) + chatHandleJoin = try TableSchema(columns: columnsFor(ChatHandleJoinTable.sqlName)) + messageAttachmentJoin = try TableSchema(columns: columnsFor(MessageAttachmentJoinTable.sqlName)) + } +} + +extension IMDatabase { + func schema() throws -> IMDatabaseSchema { + if let schemaCache { + return schemaCache + } + + let loaded = try IMDatabaseSchema { tableName in + try tableColumns(tableName) + } + schemaCache = loaded + return loaded + } +} diff --git a/src/IMessage/Sources/IMDatabase/Schema/IMDatabaseTables.swift b/src/IMessage/Sources/IMDatabase/Schema/IMDatabaseTables.swift new file mode 100644 index 00000000..a9ddc936 --- /dev/null +++ b/src/IMessage/Sources/IMDatabase/Schema/IMDatabaseTables.swift @@ -0,0 +1,222 @@ +/// Table and column names observed in `fixtures/schema-monterey.sql`, +/// `fixtures/schema-ventura.sql`, and `fixtures/schema-tahoe.sql`. + +enum SQLiteSequenceTable: IMDatabaseTable { + static let databaseTableName = "sqlite_sequence" + + enum Column: String, IMDatabaseColumn { + case name + case seq + } +} + +enum MessageTable: IMDatabaseTable { + static let databaseTableName = "message" + + enum Column: String, IMDatabaseColumn { + case rowID = "ROWID" + case guid + case text + case replace + case serviceCenter = "service_center" + case handleID = "handle_id" + case subject + case country + case attributedBody = "attributedBody" + case version + case messageType = "type" + case service + case account + case accountGUID = "account_guid" + case error + case date + case dateRead = "date_read" + case dateDelivered = "date_delivered" + case isDelivered = "is_delivered" + case isFinished = "is_finished" + case isEmote = "is_emote" + case isFromMe = "is_from_me" + case isEmpty = "is_empty" + case isDelayed = "is_delayed" + case isAutoReply = "is_auto_reply" + case isPrepared = "is_prepared" + case isRead = "is_read" + case isSystemMessage = "is_system_message" + case isSent = "is_sent" + case hasDDResults = "has_dd_results" + case isServiceMessage = "is_service_message" + case isForward = "is_forward" + case wasDowngraded = "was_downgraded" + case isArchive = "is_archive" + case cacheHasAttachments = "cache_has_attachments" + case cacheRoomnames = "cache_roomnames" + case wasDataDetected = "was_data_detected" + case wasDeduplicated = "was_deduplicated" + case isAudioMessage = "is_audio_message" + case isPlayed = "is_played" + case datePlayed = "date_played" + case itemType = "item_type" + case otherHandle = "other_handle" + case groupTitle = "group_title" + case groupActionType = "group_action_type" + case shareStatus = "share_status" + case shareDirection = "share_direction" + case isExpirable = "is_expirable" + case expireState = "expire_state" + case messageActionType = "message_action_type" + case messageSource = "message_source" + case associatedMessageGUID = "associated_message_guid" + case associatedMessageType = "associated_message_type" + case balloonBundleID = "balloon_bundle_id" + case payloadData = "payload_data" + case expressiveSendStyleID = "expressive_send_style_id" + case associatedMessageRangeLocation = "associated_message_range_location" + case associatedMessageRangeLength = "associated_message_range_length" + case timeExpressiveSendPlayed = "time_expressive_send_played" + case messageSummaryInfo = "message_summary_info" + case ckSyncState = "ck_sync_state" + case ckRecordID = "ck_record_id" + case ckRecordChangeTag = "ck_record_change_tag" + case destinationCallerID = "destination_caller_id" + case isCorrupt = "is_corrupt" + case replyToGUID = "reply_to_guid" + case sortID = "sort_id" + case isSpam = "is_spam" + case hasUnseenMention = "has_unseen_mention" + case threadOriginatorGUID = "thread_originator_guid" + case threadOriginatorPart = "thread_originator_part" + case syndicationRanges = "syndication_ranges" + case syncedSyndicationRanges = "synced_syndication_ranges" + case wasDeliveredQuietly = "was_delivered_quietly" + case didNotifyRecipient = "did_notify_recipient" + case dateRetracted = "date_retracted" + case dateEdited = "date_edited" + case wasDetonated = "was_detonated" + case partCount = "part_count" + case isStewie = "is_stewie" + case isSOS = "is_sos" + case isCritical = "is_critical" + case biaReferenceID = "bia_reference_id" + case isKTVerified = "is_kt_verified" + case fallbackHash = "fallback_hash" + case associatedMessageEmoji = "associated_message_emoji" + case isPendingSatelliteSend = "is_pending_satellite_send" + case needsRelay = "needs_relay" + case scheduleType = "schedule_type" + case scheduleState = "schedule_state" + case sentOrReceivedOffGrid = "sent_or_received_off_grid" + case dateRecovered = "date_recovered" + case isTimeSensitive = "is_time_sensitive" + case ckChatID = "ck_chat_id" + case indexState = "index_state" + } +} + +enum ChatTable: IMDatabaseTable { + static let databaseTableName = "chat" + + enum Column: String, IMDatabaseColumn { + case rowID = "ROWID" + case guid + case style + case state + case accountID = "account_id" + case properties + case chatIdentifier = "chat_identifier" + case serviceName = "service_name" + case roomName = "room_name" + case accountLogin = "account_login" + case isArchived = "is_archived" + case lastAddressedHandle = "last_addressed_handle" + case displayName = "display_name" + case groupID = "group_id" + case isFiltered = "is_filtered" + case successfulQuery = "successful_query" + case engramID = "engram_id" + case serverChangeToken = "server_change_token" + case ckSyncState = "ck_sync_state" + case originalGroupID = "original_group_id" + case lastReadMessageTimestamp = "last_read_message_timestamp" + case cloudkitRecordID = "cloudkit_record_id" + case lastAddressedSIMID = "last_addressed_sim_id" + case isBlackholed = "is_blackholed" + case syndicationDate = "syndication_date" + case syndicationType = "syndication_type" + case isRecovered = "is_recovered" + case isDeletingIncomingMessages = "is_deleting_incoming_messages" + case isPendingReview = "is_pending_review" + } +} + +enum HandleTable: IMDatabaseTable { + static let databaseTableName = "handle" + + enum Column: String, IMDatabaseColumn { + case rowID = "ROWID" + case id + case country + case service + case uncanonicalizedID = "uncanonicalized_id" + case personCentricID = "person_centric_id" + } +} + +enum AttachmentTable: IMDatabaseTable { + static let databaseTableName = "attachment" + + enum Column: String, IMDatabaseColumn { + case rowID = "ROWID" + case guid + case createdDate = "created_date" + case startDate = "start_date" + case filename + case uti + case mimeType = "mime_type" + case transferState = "transfer_state" + case isOutgoing = "is_outgoing" + case userInfo = "user_info" + case transferName = "transfer_name" + case totalBytes = "total_bytes" + case isSticker = "is_sticker" + case stickerUserInfo = "sticker_user_info" + case attributionInfo = "attribution_info" + case hideAttachment = "hide_attachment" + case ckSyncState = "ck_sync_state" + case ckServerChangeTokenBlob = "ck_server_change_token_blob" + case ckRecordID = "ck_record_id" + case originalGUID = "original_guid" + case isCommSafetySensitive = "is_commsafety_sensitive" + case emojiImageContentIdentifier = "emoji_image_content_identifier" + case emojiImageShortDescription = "emoji_image_short_description" + case previewGenerationState = "preview_generation_state" + } +} + +enum ChatMessageJoinTable: IMDatabaseTable { + static let databaseTableName = "chat_message_join" + + enum Column: String, IMDatabaseColumn { + case chatID = "chat_id" + case messageID = "message_id" + case messageDate = "message_date" + case indexState = "index_state" + } +} + +enum ChatHandleJoinTable: IMDatabaseTable { + static let databaseTableName = "chat_handle_join" + + enum Column: String, IMDatabaseColumn { + case chatID = "chat_id" + case handleID = "handle_id" + } +} + +enum MessageAttachmentJoinTable: IMDatabaseTable { + static let databaseTableName = "message_attachment_join" + + enum Column: String, IMDatabaseColumn { + case messageID = "message_id" + case attachmentID = "attachment_id" + } +} diff --git a/src/IMessage/Sources/IMDatabase/Support/ChatRef.swift b/src/IMessage/Sources/IMDatabase/Support/ChatRef.swift deleted file mode 100644 index b68e46bd..00000000 --- a/src/IMessage/Sources/IMDatabase/Support/ChatRef.swift +++ /dev/null @@ -1,50 +0,0 @@ -/// Either a `ROWID` to an iMessage chat, its `guid` column (e.g. `iMessage;-;+17075551234`), -/// or both. -/// -/// This type shouldn't be used for identification and is solely a convenience -/// type vended by methods that return query results. -public enum ChatRef { - case justRowID(Int) - case justGUID(String) - case both(rowID: Int, guid: String) -} - -public extension ChatRef { - internal init?(rowID: Int?, guid: String?) { - if let rowID, let guid { - self = .both(rowID: rowID, guid: guid) - } else if let rowID { - self = .justRowID(rowID) - } else if let guid { - self = .justGUID(guid) - } else { - return nil - } - } - - var rowID: Int? { - switch self { - case let .justRowID(rowID): rowID - case let .both(rowID, _): rowID - default: nil - } - } - - var guid: String? { - switch self { - case let .justGUID(guid): guid - case let .both(_, guid): guid - default: nil - } - } -} - -extension ChatRef: Hashable { - public func hash(into hasher: inout Hasher) { - switch self { - case let .justRowID(rowID): hasher.combine(rowID) - case let .both(rowID, _): hasher.combine(rowID) - case let .justGUID(guid): hasher.combine(guid) - } - } -} diff --git a/src/IMessage/Sources/IMDatabase/Support/Column+.swift b/src/IMessage/Sources/IMDatabase/Support/Column+.swift index 3f1e7286..9458cc45 100644 --- a/src/IMessage/Sources/IMDatabase/Support/Column+.swift +++ b/src/IMessage/Sources/IMDatabase/Support/Column+.swift @@ -1,9 +1,29 @@ import Foundation -import SQLite +import GRDB -extension Column { - consuming func imCoreDate() throws -> Date? { - guard let nanoseconds = try optionalConverting(Int.self) else { +extension Row { + func optionalString(at index: Int) -> String? { + self[index] as String? + } + + func optionalInt(at index: Int) -> Int? { + self[index] as Int? + } + + func optionalData(at index: Int) -> Data? { + self[index] as Data? + } + + func requiredString(at index: Int) -> String { + self[index] as String + } + + func requiredInt(at index: Int) -> Int { + self[index] as Int + } + + func imCoreDate(at index: Int) -> Date? { + guard let nanoseconds = optionalInt(at: index) else { return nil } @@ -22,8 +42,8 @@ extension Column { return date } - consuming func looseBool() throws -> Bool { - guard let integer = try optionalConverting(Int.self) else { + func looseBool(at index: Int) -> Bool { + guard let integer = optionalInt(at: index) else { return false } diff --git a/src/IMessage/Sources/IMDatabaseTestBench/TestBench.swift b/src/IMessage/Sources/IMDatabaseTestBench/TestBench.swift index 4cffdec8..badaabb1 100644 --- a/src/IMessage/Sources/IMDatabaseTestBench/TestBench.swift +++ b/src/IMessage/Sources/IMDatabaseTestBench/TestBench.swift @@ -186,14 +186,12 @@ extension TestBench { bootstrap(logLevel: options.logLevel) let db = try IMDatabase() - let states = try Dictionary(uniqueKeysWithValues: db.chatStates().map { chatRef, state in - (chatRef.rowID!, state) - }) + let states = try db.chatStates() for chat in try db.chats() where filter.allSatisfy({ $0.test(against: chat) }) { chat.dump() - if let state = states[chat.id] { + if let state = states[chat.guid.description] { if #available(macOS 12, *) { let relativeDate = state.lastReadMessageTimestamp.formatted(.relative(presentation: .numeric, unitsStyle: .wide)) print("- \(state) (\(relativeDate))") @@ -250,9 +248,9 @@ extension TestBench { let newStates = try db.chatStates() defer { states = newStates } - var changedChatStates: [ChatRef: ChatState] = [:] - for (chatID, newState) in newStates where states[chatID] != newState { - changedChatStates[chatID] = newState + var changedChatStates: [String: ChatState] = [:] + for (chatGUID, newState) in newStates where states[chatGUID] != newState { + changedChatStates[chatGUID] = newState } print("changed unread states:", changedChatStates) diff --git a/src/IMessage/Sources/IMDatabaseTests/LiveSQLTests.swift b/src/IMessage/Sources/IMDatabaseTests/LiveSQLTests.swift new file mode 100644 index 00000000..b02e98a2 --- /dev/null +++ b/src/IMessage/Sources/IMDatabaseTests/LiveSQLTests.swift @@ -0,0 +1,803 @@ +import Foundation +import GRDB +@testable import IMDatabase +import Testing + +private enum LocalMessagesDatabase { + static var messagesDirectory: URL { + if let override = ProcessInfo.processInfo.environment["IMDATABASE_TEST_MESSAGES_DIR"], !override.isEmpty { + return URL(fileURLWithPath: override, isDirectory: true) + } + return URL(fileURLWithPath: "\(NSHomeDirectory())/Library/Messages/", isDirectory: true) + } + + static var chatDBURL: URL { + messagesDirectory.appendingPathComponent("chat.db") + } + + static var isReadable: Bool { + FileManager.default.isReadableFile(atPath: chatDBURL.path) + } + + static let fullDiskAccessRequest: Void = { + guard !isReadable else { return } + do { + try Process.run( + URL(fileURLWithPath: "/usr/bin/open"), + arguments: ["x-apple.systempreferences:com.apple.preference.security?Privacy_AllFiles"] + ) + } catch {} + }() + + static func requireReadable() throws { + guard isReadable else { + _ = fullDiskAccessRequest + throw FullDiskAccessRequired(chatDBURL: chatDBURL) + } + } + + static func imDatabase() throws -> IMDatabase { + try requireReadable() + return try IMDatabase(messagesDataBaseURL: messagesDirectory) + } + + static func queue() throws -> DatabaseQueue { + try requireReadable() + var configuration = Configuration() + configuration.readonly = true + return try DatabaseQueue(path: chatDBURL.path, configuration: configuration) + } +} + +private struct FullDiskAccessRequired: Error, CustomStringConvertible { + var chatDBURL: URL + + var description: String { + """ + IMDatabaseLiveSQLTests need Full Disk Access to read \(chatDBURL.path). + + The test runner opened System Settings > Privacy & Security > Full Disk Access. + Grant access to the app that launched these tests, then rerun the suite. + + Common cases: + - Xcode test run: grant Xcode Full Disk Access. + - Terminal swift test: grant that terminal app Full Disk Access. + - Codex/local tool run: grant the host app Full Disk Access. + + Override the Messages directory with IMDATABASE_TEST_MESSAGES_DIR if needed. + """ + } +} + +private struct SampleChat { + var rowID: Int + var guid: String + var latestMessageDate: Int? +} + +private struct SampleMessage { + var rowID: Int + var guid: String + var chatRowID: Int + var chatGUID: String + var messageDate: Int? +} + +@Suite("IMDatabase live SQL", .serialized) +struct IMDatabaseLiveSQLTests { + @Test("loads schema from local chat.db") + func schemaLoadsAllKnownTables() throws { + let db = try LocalMessagesDatabase.imDatabase() + let schema = try db.schema() + + #expect(schema.sqliteSequence.has(.name)) + #expect(schema.sqliteSequence.has(.seq)) + #expect(schema.message.has(.rowID)) + #expect(schema.message.has(.guid)) + #expect(schema.message.has(.date)) + #expect(schema.message.has(.dateRead)) + #expect(schema.chat.has(.rowID)) + #expect(schema.chat.has(.guid)) + #expect(schema.chat.has(.serviceName)) + #expect(schema.handle.has(.rowID)) + #expect(schema.handle.has(.id)) + #expect(schema.attachment.has(.rowID)) + #expect(schema.attachment.has(.guid)) + #expect(schema.chatMessageJoin.has(.chatID)) + #expect(schema.chatMessageJoin.has(.messageID)) + #expect(schema.chatHandleJoin.has(.chatID)) + #expect(schema.chatHandleJoin.has(.handleID)) + #expect(schema.messageAttachmentJoin.has(.messageID)) + #expect(schema.messageAttachmentJoin.has(.attachmentID)) + } + + @Test("matches account, chat, thread GUID, and participant queries") + func accountAndChatQueriesMatchRawSQL() throws { + let db = try LocalMessagesDatabase.imDatabase() + let queue = try LocalMessagesDatabase.queue() + + let expectedAccountLogins = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT DISTINCT account_login + FROM chat + """).compactMap { $0[0] as String? }.sorted() + } + #expect(try db.accountLogins().sorted() == expectedAccountLogins) + + let expectedChatCount = try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: "SELECT COUNT(*) FROM chat WHERE guid IS NOT NULL") ?? 0 + } + #expect(try db.chats().count == expectedChatCount) + + let chat = try #require(try Self.latestChat(queue: queue)) + let fetchedChat = try #require(try db.chat(withGUID: chat.guid)) + #expect(fetchedChat.id == chat.rowID) + #expect(fetchedChat.guid.description == chat.guid) + + let expectedThreadGUIDs = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: "SELECT guid FROM chat").compactMap { $0[0] as String? }.sorted() + } + #expect(try db.allThreadGUIDs().sorted() == expectedThreadGUIDs) + + let expectedHandles = try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: """ + SELECT COUNT(*) + FROM chat + INNER JOIN chat_handle_join ON chat_handle_join.chat_id = chat.ROWID + INNER JOIN handle ON handle.ROWID = chat_handle_join.handle_id + WHERE chat.guid = ? + """, arguments: [chat.guid]) ?? 0 + } + #expect(try db.handles(inChatWithGUID: chat.guid).count == expectedHandles) + } + + @Test("matches legacy message and attachment queries") + func legacyMessageAndAttachmentQueriesMatchRawSQL() throws { + let db = try LocalMessagesDatabase.imDatabase() + let queue = try LocalMessagesDatabase.queue() + let sample = try #require(try Self.latestMessage(queue: queue)) + + let fetched = try #require(try db.message( + with: GUID(stringLiteral: sample.guid), + withAttachments: true + )) + #expect(fetched.message.id == sample.rowID) + #expect(fetched.message.guid.description == sample.guid) + #expect(fetched.chatGUID.description == sample.chatGUID) + #expect(fetched.message.attachments != nil) + + let expectedAttachmentCount = try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: """ + SELECT COUNT(*) + FROM message_attachment_join AS maj + INNER JOIN attachment AS a ON a.ROWID = maj.attachment_id + WHERE maj.message_id = ? + """, arguments: [sample.rowID]) ?? 0 + } + #expect(fetched.message.attachments?.count == expectedAttachmentCount) + + let expectedMessageIDs = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT m.ROWID + FROM message AS m + LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID + LEFT JOIN chat AS c ON cmj.chat_id = c.ROWID + WHERE c.guid = ? + ORDER BY m.date DESC + LIMIT 5 + """, arguments: [sample.chatGUID]).compactMap { $0[0] as Int? } + } + let messages = try Array(db.messages( + in: GUID(stringLiteral: sample.chatGUID), + order: .newestFirst, + limit: 5, + withAttachments: true + )) + #expect(messages.map(\.id) == expectedMessageIDs) + #expect(messages.allSatisfy { $0.attachments != nil }) + + if let messageDate = sample.messageDate { + let beforeDate = Date(nanosecondsSinceReferenceDate: messageDate) + let expectedBeforeIDs = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT m.ROWID + FROM message AS m + LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID + LEFT JOIN chat AS c ON cmj.chat_id = c.ROWID + WHERE c.guid = ? AND m.date < ? + ORDER BY m.date DESC + LIMIT 3 + """, arguments: databaseArguments([sample.chatGUID, messageDate])).compactMap { $0[0] as Int? } + } + let beforeMessages = try Array(db.messages( + in: GUID(stringLiteral: sample.chatGUID), + filter: .before(beforeDate), + order: .newestFirst, + limit: 3, + withAttachments: false + )) + #expect(beforeMessages.map(\.id) == expectedBeforeIDs) + + let expectedAfterIDs = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT m.ROWID + FROM message AS m + LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID + LEFT JOIN chat AS c ON cmj.chat_id = c.ROWID + WHERE c.guid = ? AND m.date > ? + ORDER BY m.date ASC + LIMIT 3 + """, arguments: databaseArguments([sample.chatGUID, messageDate])).compactMap { $0[0] as Int? } + } + let afterMessages = try Array(db.messages( + in: GUID(stringLiteral: sample.chatGUID), + filter: .after(beforeDate), + order: .oldestFirst, + limit: 3, + withAttachments: false + )) + #expect(afterMessages.map(\.id) == expectedAfterIDs) + } + } + + @Test("matches mapped thread queries") + func mappedThreadQueriesMatchRawSQL() throws { + let db = try LocalMessagesDatabase.imDatabase() + let queue = try LocalMessagesDatabase.queue() + + let expectedThreadIDs = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT chat.ROWID + FROM chat + ORDER BY (SELECT MAX(message_date) FROM chat_message_join WHERE chat_id = chat.ROWID) DESC + LIMIT 10 + """).compactMap { $0[0] as Int? } + } + let threadRows = try db.mappedThreadRows(cursor: nil, direction: nil, limit: 10) + #expect(threadRows.map(\.rowID) == expectedThreadIDs) + + if let cursor = threadRows.dropFirst().first?.msgDate { + let rowsBeforeCursor = try db.mappedThreadRows(cursor: String(cursor), direction: .before, limit: 5) + #expect(rowsBeforeCursor.allSatisfy { ($0.msgDate ?? Int.min) < cursor }) + + let rowsAfterCursor = try db.mappedThreadRows(cursor: String(cursor), direction: .after, limit: 5) + #expect(rowsAfterCursor.allSatisfy { ($0.msgDate ?? Int.max) > cursor }) + } + + let chat = try #require(try Self.latestChat(queue: queue)) + let threadRow = try #require(try db.mappedThreadRow(guid: chat.guid)) + #expect(threadRow.rowID == chat.rowID) + #expect(threadRow.guid == chat.guid) + + let chatRowIDs = Array(threadRows.prefix(5).map(\.rowID)) + let participantRows = try db.mappedThreadParticipantRows(chatRowIDs: chatRowIDs) + let expectedParticipantCounts = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT chj.chat_id, COUNT(*) + FROM handle + LEFT JOIN chat_handle_join AS chj ON chj.handle_id = handle.ROWID + WHERE chj.chat_id IN (\(placeholders(count: chatRowIDs.count))) + GROUP BY chj.chat_id + """, arguments: StatementArguments(chatRowIDs)).reduce(into: [:]) { result, row in + result[row[0] as Int] = row[1] as Int + } + } + #expect(participantRows.mapValues(\.count) == expectedParticipantCounts) + + let unreadCounts = try db.mappedUnreadCounts(chatRowIDs: chatRowIDs) + let expectedUnreadCounts = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT cm.chat_id, COUNT(cm.chat_id) + FROM message AS m + INNER JOIN chat_message_join AS cm ON m.ROWID = cm.message_id + WHERE m.item_type == 0 + AND m.is_read == 0 + AND m.is_from_me == 0 + AND cm.chat_id IN (\(placeholders(count: chatRowIDs.count))) + GROUP BY cm.chat_id + """, arguments: StatementArguments(chatRowIDs)).reduce(into: [:]) { result, row in + result[row[0] as Int] = row[1] as Int + } + } + #expect(unreadCounts == expectedUnreadCounts) + } + + @Test("matches mapped message paging and batch queries") + func mappedMessageQueriesMatchRawSQL() throws { + let db = try LocalMessagesDatabase.imDatabase() + let queue = try LocalMessagesDatabase.queue() + let chat = try #require(try Self.chatWithAtLeastMessages(queue: queue, count: 3)) + + #expect(try db.mappedChatRowID(guid: chat.guid) == chat.rowID) + + let expectedNewestMessageIDs = try Self.messageIDs( + queue: queue, + chatRowID: chat.rowID, + order: "DESC", + limit: 5 + ) + let messageRows = try db.mappedMessageRows(in: chat.guid, cursor: nil, direction: nil, limit: 5) + #expect(messageRows.map(\.rowID) == expectedNewestMessageIDs) + #expect(messageRows.allSatisfy { $0.threadID == chat.guid }) + + let cursor = try #require(try Self.messageCursor(queue: queue, chatRowID: chat.rowID, offset: 1)) + let expectedBeforeCursor = try Self.messageIDs( + queue: queue, + chatRowID: chat.rowID, + cursorSQL: "AND cmj.message_date < ?", + cursor: cursor, + order: "DESC", + limit: 5 + ) + let beforeCursorRows = try db.mappedMessageRows(in: chat.guid, cursor: String(cursor), direction: .before, limit: 5) + #expect(beforeCursorRows.map(\.rowID) == expectedBeforeCursor) + + let dateExpression = try db.schema().message.has(.dateEdited) + ? "MAX(m.date, COALESCE(m.date_edited, 0))" + : "cmj.message_date" + let expectedAfterCursor = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT cmj.message_id + FROM chat_message_join AS cmj + INNER JOIN message AS m ON m.ROWID = cmj.message_id + WHERE cmj.chat_id = ? AND \(dateExpression) > ? + ORDER BY cmj.message_date ASC, cmj.message_id ASC + LIMIT 5 + """, arguments: databaseArguments([chat.rowID, cursor])).compactMap { $0[0] as Int? } + } + let afterCursorRows = try db.mappedMessageRows(in: chat.guid, cursor: String(cursor), direction: .after, limit: 5) + #expect(afterCursorRows.map(\.rowID) == expectedAfterCursor) + + let expectedRowsByGUID = Array(messageRows.prefix(3)) + let rowsByGUID = try db.mappedMessageRows(guids: expectedRowsByGUID.map(\.guid) + [expectedRowsByGUID[0].guid]) + #expect(Set(rowsByGUID.map(\.rowID)) == Set(expectedRowsByGUID.map(\.rowID))) + + let rowsByID = try db.mappedMessageRows(rowIDs: expectedRowsByGUID.map(\.rowID) + [expectedRowsByGUID[0].rowID]) + let expectedRowsByID = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT m.ROWID + FROM message AS m + WHERE m.ROWID IN (\(placeholders(count: expectedRowsByGUID.count))) + ORDER BY m.date DESC + """, arguments: StatementArguments(expectedRowsByGUID.map(\.rowID))).compactMap { $0[0] as Int? } + } + #expect(rowsByID.map(\.rowID) == expectedRowsByID) + + let latestRows = try db.mappedLatestMessageRows(chatRowIDs: [chat.rowID]) + let latestRow = try #require(latestRows[chat.guid]) + #expect(latestRow.rowID == expectedNewestMessageIDs.first) + } + + @Test("matches mapped attachment and reaction queries") + func mappedAttachmentAndReactionQueriesMatchRawSQL() throws { + let db = try LocalMessagesDatabase.imDatabase() + let queue = try LocalMessagesDatabase.queue() + + if let attachmentSample = try Self.messageWithAttachment(queue: queue) { + let attachmentRows = try db.mappedAttachmentRows(messageRowIDs: [attachmentSample.messageRowID]) + let expectedAttachmentRows = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT m.ROWID AS msgRowID, a.filename, a.transfer_name, a.total_bytes, a.is_sticker, a.guid AS attachmentID, a.transfer_state + FROM message AS m + LEFT JOIN message_attachment_join AS maj ON maj.message_id = m.ROWID + LEFT JOIN attachment AS a ON a.ROWID = maj.attachment_id + WHERE m.ROWID = ? + """, arguments: [attachmentSample.messageRowID]) + } + #expect(attachmentRows.count == expectedAttachmentRows.count) + #expect(attachmentRows.first?.msgRowID == attachmentSample.messageRowID) + #expect(try db.attachmentFilename(guid: attachmentSample.attachmentGUID) == attachmentSample.filename) + #expect(try db.attachmentFilename(messageRowID: attachmentSample.messageRowID) == attachmentSample.filename) + } + + if let reactionSample = try Self.reactionTarget(queue: queue) { + let reactionRows = try db.mappedReactionRows( + messageGUIDs: [reactionSample.targetGUID], + chatRowIDs: [reactionSample.chatRowID] + ) + let expectedReactionRowIDs = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT m.ROWID + FROM message AS m + LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID + WHERE REPLACE(SUBSTR(associated_message_guid, INSTR(associated_message_guid, '/') + 1), 'bp:', '') = ? + AND cmj.chat_id = ? + ORDER BY m.ROWID ASC + """, arguments: databaseArguments([reactionSample.targetGUID, reactionSample.chatRowID])).compactMap { $0[0] as Int? } + } + #expect(reactionRows.map(\.rowID) == expectedReactionRowIDs) + + let chatGUID = try queue.read { rawDB in + try String.fetchOne(rawDB, sql: "SELECT guid FROM chat WHERE ROWID = ?", arguments: [reactionSample.chatRowID]) + } + if let chatGUID { + let reactionRowsByGUID = try db.mappedReactionRows(messageGUIDs: [reactionSample.targetGUID], chatGUID: chatGUID) + #expect(reactionRowsByGUID.map(\.rowID) == expectedReactionRowIDs) + } + } + } + + @Test("matches unread state, update cursor, sent message, and delta queries") + func updateAndUnreadQueriesMatchRawSQL() throws { + let db = try LocalMessagesDatabase.imDatabase() + let queue = try LocalMessagesDatabase.queue() + + let expectedStates: [String: (unreadCount: Int, lastRead: Int)] = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: unreadStatesQuery).reduce(into: [:]) { result, row in + let guid: String = row[0] + result[guid] = ( + unreadCount: row[1], + lastRead: row[2] + ) + } + } + let states = try db.chatStates() + #expect(states.count == expectedStates.count) + for (guid, expected) in expectedStates { + let state = try #require(states[guid]) + #expect(state.unreadCount == expected.unreadCount) + expectClose(state.lastReadMessageTimestamp.nanosecondsSinceReferenceDate, expected.lastRead) + } + + let unreadSample = try #require(try Self.latestChat(queue: queue)) + let sampleUnreadCounts = try db.mappedUnreadCounts(chatRowIDs: [unreadSample.rowID]) + #expect(try db.isThreadRead(chatGUID: unreadSample.guid) == ((sampleUnreadCounts[unreadSample.rowID] ?? 0) == 0)) + + let rawLastMessageRowID = try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: "SELECT seq FROM sqlite_sequence WHERE name = 'message'") ?? 0 + } + #expect(try db.lastMessageRowID() == rawLastMessageRowID) + + let rawMaxDateRead = try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: "SELECT MAX(date_read) FROM message") ?? 0 + } + let rawMaxDateEdited = try db.schema().message.has(.dateEdited) + ? queue.read { rawDB in + try Int.fetchOne(rawDB, sql: "SELECT MAX(date_edited) FROM message") ?? 0 + } + : 0 + expectClose(try db.maxMessageDateRead().nanosecondsSinceReferenceDate, rawMaxDateRead) + + let snapshot = try db.messageUpdateCursorSnapshot() + #expect(snapshot.lastRowID == rawLastMessageRowID) + expectClose(snapshot.lastDateRead.nanosecondsSinceReferenceDate, rawMaxDateRead) + expectClose(snapshot.lastDateEdited.nanosecondsSinceReferenceDate, rawMaxDateEdited) + + let threshold = max(0, rawLastMessageRowID - 1000) + let expectedSent = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT ROWID, guid + FROM message + WHERE is_from_me = 1 AND ROWID > ? + """, arguments: [threshold]).map { row in + (rowID: row[0] as Int, guid: row[1] as String) + } + } + let sent = try db.sentMessageIDs(since: threshold) + #expect(sent.map(\.rowID) == expectedSent.map(\.rowID)) + + let sample = try #require(try Self.latestMessage(queue: queue)) + #expect(try db.threadIDForMessage(rowID: sample.rowID) == sample.chatGUID) + + let deltas = try db.messages( + newerThanRowID: threshold, + orReadSince: Date(nanosecondsSinceReferenceDate: rawMaxDateRead), + orEditedSince: Date(nanosecondsSinceReferenceDate: rawMaxDateEdited) + ) + let expectedDeltaRowIDs = try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT m.ROWID + FROM message AS m + LEFT JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID + LEFT JOIN chat AS c ON cmj.chat_id = c.ROWID + WHERE m.ROWID > ? AND c.guid IS NOT NULL + ORDER BY m.ROWID ASC + """, arguments: [threshold]).compactMap { $0[0] as Int? } + } + #expect(deltas.updatedMessages.map(\.rowID) == expectedDeltaRowIDs) + } + + @Test("search results point at messages containing the query") + func searchMessagesReturnsMatchingRows() throws { + let db = try LocalMessagesDatabase.imDatabase() + let queue = try LocalMessagesDatabase.queue() + let query = "a" + + let rowIDs = try db.searchMessages(query: query, limit: 10) + #expect(rowIDs.count <= 10) + + for rowID in rowIDs { + let matches = try Self.messageTextMatches(rowID: rowID, query: query, queue: queue) + #expect(matches) + } + + let sample = try #require(try Self.latestMessage(queue: queue)) + let chatFilteredRowIDs = try db.searchMessages(query: query, chatGUID: sample.chatGUID, limit: 5) + for rowID in chatFilteredRowIDs { + let belongsToChat = try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: """ + SELECT COUNT(*) + FROM chat_message_join AS cmj + INNER JOIN chat AS c ON c.ROWID = cmj.chat_id + WHERE cmj.message_id = ? AND c.guid = ? + """, arguments: databaseArguments([rowID, sample.chatGUID])) ?? 0 + } + #expect(belongsToChat > 0) + #expect(try Self.messageTextMatches(rowID: rowID, query: query, queue: queue)) + } + + let mediaRowIDs = try db.searchMessages(query: query, mediaOnly: true, limit: 5) + for rowID in mediaRowIDs { + let hasAttachments = try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: "SELECT cache_has_attachments FROM message WHERE ROWID = ?", arguments: [rowID]) ?? 0 + } + #expect(hasAttachments == 1) + #expect(try Self.messageTextMatches(rowID: rowID, query: query, queue: queue)) + } + + for sender in ["me", "others"] { + let senderRowIDs = try db.searchMessages(query: query, sender: sender, limit: 5) + for rowID in senderRowIDs { + let isFromMe = try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: "SELECT is_from_me FROM message WHERE ROWID = ?", arguments: [rowID]) ?? -1 + } + #expect(isFromMe == (sender == "me" ? 1 : 0)) + #expect(try Self.messageTextMatches(rowID: rowID, query: query, queue: queue)) + } + } + } + + private static func messageTextMatches(rowID: Int, query: String, queue: DatabaseQueue) throws -> Bool { + try queue.read { rawDB in + let row = try #require(try Row.fetchOne(rawDB, sql: """ + SELECT text, attributedBody + FROM message + WHERE ROWID = ? + """, arguments: [rowID])) + let plainText = row[0] as String? + let attributedBody = row[1] as Data? + let decodedText = attributedBody.flatMap { try? AttributedBodyDecoder.plainText(from: $0) } + let text = decodedText?.isEmpty == false ? decodedText : plainText + return text?.lowercased().contains(query) == true + } + } + + @Test("benchmarks hot local SQL paths") + func benchmarkHotLocalSQLPaths() throws { + let db = try LocalMessagesDatabase.imDatabase() + let iterations = max(1, Int(ProcessInfo.processInfo.environment["IMDATABASE_BENCHMARK_ITERATIONS"] ?? "") ?? 5) + let queue = try LocalMessagesDatabase.queue() + let chat = try #require(try Self.chatWithAtLeastMessages(queue: queue, count: 3)) + let threadRows = try db.mappedThreadRows(cursor: nil, direction: nil, limit: 25) + let chatRowIDs = Array(threadRows.prefix(25).map(\.rowID)) + let messageRows = try db.mappedMessageRows(in: chat.guid, cursor: nil, direction: nil, limit: 25) + let messageRowIDs = messageRows.map(\.rowID) + let messageGUIDs = messageRows.map(\.guid) + let reactionSample = try Self.reactionTarget(queue: queue) + + try measure("mappedThreadRows", iterations: iterations) { + try db.mappedThreadRows(cursor: nil, direction: nil, limit: 25).count + } + try measure("mappedLatestMessageRows", iterations: iterations) { + try db.mappedLatestMessageRows(chatRowIDs: chatRowIDs).count + } + try measure("mappedThreadParticipantRows", iterations: iterations) { + try db.mappedThreadParticipantRows(chatRowIDs: chatRowIDs).values.reduce(0) { $0 + $1.count } + } + try measure("mappedUnreadCounts", iterations: iterations) { + try db.mappedUnreadCounts(chatRowIDs: chatRowIDs).count + } + try measure("mappedMessageRows.page", iterations: iterations) { + try db.mappedMessageRows(in: chat.guid, cursor: nil, direction: nil, limit: 25).count + } + try measure("mappedMessageRows.rowIDs", iterations: iterations) { + try db.mappedMessageRows(rowIDs: messageRowIDs).count + } + try measure("mappedMessageRows.guids", iterations: iterations) { + try db.mappedMessageRows(guids: messageGUIDs).count + } + try measure("mappedAttachmentRows", iterations: iterations) { + try db.mappedAttachmentRows(messageRowIDs: messageRowIDs).count + } + try measure("mappedReactionRows", iterations: iterations) { + if let reactionSample { + return try db.mappedReactionRows( + messageGUIDs: [reactionSample.targetGUID], + chatRowID: reactionSample.chatRowID + ).count + } + return try db.mappedReactionRows(messageGUIDs: messageGUIDs, chatRowID: chat.rowID).count + } + try measure("messageUpdateCursorSnapshot", iterations: iterations) { + let snapshot = try db.messageUpdateCursorSnapshot() + return snapshot.lastRowID + } + try measure("chatStates", iterations: iterations) { + try db.chatStates().count + } + try measure("searchMessages", iterations: iterations) { + try db.searchMessages(query: "a", limit: 20).count + } + } +} + +private extension IMDatabaseLiveSQLTests { + static func latestChat(queue: DatabaseQueue) throws -> SampleChat? { + try queue.read { rawDB in + try Row.fetchOne(rawDB, sql: """ + SELECT c.ROWID, c.guid, MAX(cmj.message_date) AS latestMessageDate + FROM chat AS c + LEFT JOIN chat_message_join AS cmj ON cmj.chat_id = c.ROWID + WHERE c.guid IS NOT NULL + GROUP BY c.ROWID + ORDER BY latestMessageDate DESC + LIMIT 1 + """).map { row in + SampleChat(rowID: row[0] as Int, guid: row[1] as String, latestMessageDate: row[2] as Int?) + } + } + } + + static func chatWithAtLeastMessages(queue: DatabaseQueue, count: Int) throws -> SampleChat? { + try queue.read { rawDB in + try Row.fetchOne(rawDB, sql: """ + SELECT c.ROWID, c.guid, MAX(cmj.message_date) AS latestMessageDate + FROM chat AS c + INNER JOIN chat_message_join AS cmj ON cmj.chat_id = c.ROWID + WHERE c.guid IS NOT NULL + GROUP BY c.ROWID + HAVING COUNT(cmj.message_id) >= ? + ORDER BY latestMessageDate DESC + LIMIT 1 + """, arguments: [count]).map { row in + SampleChat(rowID: row[0] as Int, guid: row[1] as String, latestMessageDate: row[2] as Int?) + } + } + } + + static func latestMessage(queue: DatabaseQueue) throws -> SampleMessage? { + try queue.read { rawDB in + try Row.fetchOne(rawDB, sql: """ + SELECT m.ROWID, m.guid, c.ROWID, c.guid, cmj.message_date + FROM message AS m + INNER JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID + INNER JOIN chat AS c ON c.ROWID = cmj.chat_id + WHERE m.guid IS NOT NULL AND c.guid IS NOT NULL + ORDER BY m.date DESC + LIMIT 1 + """).map { row in + SampleMessage( + rowID: row[0] as Int, + guid: row[1] as String, + chatRowID: row[2] as Int, + chatGUID: row[3] as String, + messageDate: row[4] as Int? + ) + } + } + } + + static func messageCursor(queue: DatabaseQueue, chatRowID: Int, offset: Int) throws -> Int? { + try queue.read { rawDB in + try Int.fetchOne(rawDB, sql: """ + SELECT cmj.message_date + FROM chat_message_join AS cmj + INNER JOIN message AS m ON m.ROWID = cmj.message_id + WHERE cmj.chat_id = ? + ORDER BY cmj.message_date DESC, cmj.message_id DESC + LIMIT 1 OFFSET \(offset) + """, arguments: [chatRowID]) + } + } + + static func messageIDs( + queue: DatabaseQueue, + chatRowID: Int, + cursorSQL: String = "", + cursor: Int? = nil, + order: String, + limit: Int + ) throws -> [Int] { + var arguments: [any DatabaseValueConvertible] = [chatRowID] + if let cursor { + arguments.append(cursor) + } + arguments.append(limit) + return try queue.read { rawDB in + try Row.fetchAll(rawDB, sql: """ + SELECT cmj.message_id + FROM chat_message_join AS cmj + INNER JOIN message AS m ON m.ROWID = cmj.message_id + WHERE cmj.chat_id = ? + \(cursorSQL) + ORDER BY cmj.message_date \(order), cmj.message_id \(order) + LIMIT ? + """, arguments: StatementArguments(arguments)).compactMap { $0[0] as Int? } + } + } + + static func messageWithAttachment(queue: DatabaseQueue) throws -> (messageRowID: Int, attachmentGUID: String, filename: String?)? { + try queue.read { rawDB in + try Row.fetchOne(rawDB, sql: """ + SELECT m.ROWID, a.guid, a.filename + FROM message AS m + INNER JOIN message_attachment_join AS maj ON maj.message_id = m.ROWID + INNER JOIN attachment AS a ON a.ROWID = maj.attachment_id + WHERE a.guid IS NOT NULL + ORDER BY m.date DESC + LIMIT 1 + """).map { row in + (messageRowID: row[0] as Int, attachmentGUID: row[1] as String, filename: row[2] as String?) + } + } + } + + static func reactionTarget(queue: DatabaseQueue) throws -> (targetGUID: String, chatRowID: Int)? { + try queue.read { rawDB in + try Row.fetchOne(rawDB, sql: """ + SELECT normalized_target_guid, chat_id + FROM ( + SELECT + REPLACE(SUBSTR(m.associated_message_guid, INSTR(m.associated_message_guid, '/') + 1), 'bp:', '') AS normalized_target_guid, + cmj.chat_id AS chat_id + FROM message AS m + INNER JOIN chat_message_join AS cmj ON cmj.message_id = m.ROWID + WHERE m.associated_message_guid IS NOT NULL + ) + WHERE normalized_target_guid IS NOT NULL AND normalized_target_guid != '' + GROUP BY normalized_target_guid, chat_id + ORDER BY COUNT(*) DESC + LIMIT 1 + """).map { row in + (targetGUID: row[0] as String, chatRowID: row[1] as Int) + } + } + } +} + +private func placeholders(count: Int) -> String { + Array(repeating: "?", count: count).joined(separator: ", ") +} + +private func databaseArguments(_ values: [Any]) -> StatementArguments { + guard let arguments = StatementArguments(values) else { + preconditionFailure("all test SQL arguments must be database values") + } + return arguments +} + +private func expectClose(_ actual: Int, _ expected: Int, tolerance: Int = 1_000_000) { + #expect(abs(actual - expected) <= tolerance) +} + +private func measure(_ name: String, iterations: Int, _ operation: () throws -> Int) throws { + let clock = ContinuousClock() + _ = try operation() + + var samples: [Double] = [] + var resultCount = 0 + for _ in 0.. Double { + let components = duration.components + return Double(components.seconds) * 1000 + Double(components.attoseconds) / 1_000_000_000_000_000 +} diff --git a/src/IMessage/Sources/IMessage/EventWatcher/ChatRef+Description.swift b/src/IMessage/Sources/IMessage/EventWatcher/ChatRef+Description.swift deleted file mode 100644 index d849591b..00000000 --- a/src/IMessage/Sources/IMessage/EventWatcher/ChatRef+Description.swift +++ /dev/null @@ -1,12 +0,0 @@ -import IMDatabase - -// A bit gross, but `IMDatabase` shouldn't know what a hasher is. -extension ChatRef: CustomStringConvertible { - public var description: String { - if let guid { - Hasher.thread.tokenizeRemembering(pii: guid) - } else { - "chat#\(rowID!)" - } - } -} diff --git a/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher+Unreads.swift b/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher+Unreads.swift index ea461544..757871e3 100644 --- a/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher+Unreads.swift +++ b/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher+Unreads.swift @@ -13,30 +13,25 @@ extension EventWatcher { /// Diffs current chat states against the previous snapshot and returns events for any changes. func diffChatStates() throws -> [ServerEvent] { // Grab the latest set, and remember it for the next database change. - let currentChatStates: [ChatRef: ChatState] = try db.chatStates() + let currentChatStates: [String: ChatState] = try db.chatStates() var eventsToSend = [ServerEvent]() var changes = 0 - for (chatRef, currentState) in currentChatStates { - guard chatStates[chatRef]?.state != currentState else { + for (chatGUID, currentState) in currentChatStates { + guard chatStates[chatGUID]?.state != currentState else { // Unread state didn't change, so a state sync is unnecessary. continue } defer { changes += 1 } - guard let guid = chatRef.guid else { - log.error("didn't receive a guid for chat that underwent an unread state change") - continue - } - // Minting a new timestamped chat state like this also ensures // that we handle new (not just updated) chats correctly. let fresh = TimestampedChatState(minting: currentState) - chatStates[chatRef] = fresh + chatStates[chatGUID] = fresh - let hashedThreadID = Hasher.thread.tokenizeRemembering(pii: guid) + let hashedThreadID = Hasher.thread.tokenizeRemembering(pii: chatGUID) let lastReadMessageSortKey = (currentState.lastReadMessageTimestamp.timeIntervalSince1970 * 1000).rounded() let isUnread = currentState.unreadCount > 0 let markedUnreadUpdatedAt = Int(fresh.lastUpdated.timeIntervalSince1970 * 1000) @@ -74,7 +69,7 @@ extension EventWatcher { "markedUnreadUpdatedAt": markedUnreadUpdatedAt, ] - traceUnreads("chat \(chatRef) patch: lastReadMessageSortKey=\(lastReadMessageSortKey), isMarkedUnread=\(isUnread), markedUnreadUpdatedAt=\(markedUnreadUpdatedAt)") + traceUnreads("chat \(hashedThreadID) patch: lastReadMessageSortKey=\(lastReadMessageSortKey), isMarkedUnread=\(isUnread), markedUnreadUpdatedAt=\(markedUnreadUpdatedAt)") if currentState.unreadCount == 0 { // Sync the fact that the thread became read. This is especially @@ -90,21 +85,17 @@ extension EventWatcher { eventsToSend.append(ServerEvent.stateSyncThread(id: hashedThreadID, patch: patch)) - traceUnreads("chat \(chatRef) unread state changed to: \(fresh)") + traceUnreads("chat \(hashedThreadID) unread state changed to: \(fresh)") } traceUnreads("\(changes) unread state(s) changed this time around") // Detect chats that were deleted from iMessage since the last database change. let deletedChats = chatStates.keys.filter { currentChatStates[$0] == nil } - let deletedThreadIDs = deletedChats.compactMap { chat -> String? in - chatStates.removeValue(forKey: chat) - guard let guid = chat.guid else { - log.error("deleted chat didn't have a guid, can't emit a delete event") - return nil - } - log.info("chat \(guid) was deleted from iMessage") - return Hasher.thread.tokenizeRemembering(pii: guid) + let deletedThreadIDs = deletedChats.map { chatGUID -> String in + chatStates.removeValue(forKey: chatGUID) + log.info("chat \(chatGUID) was deleted from iMessage") + return Hasher.thread.tokenizeRemembering(pii: chatGUID) } if !deletedThreadIDs.isEmpty { diff --git a/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher+Updates.swift b/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher+Updates.swift index ed527018..f8adca42 100644 --- a/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher+Updates.swift +++ b/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher+Updates.swift @@ -1,5 +1,7 @@ +import Collections import Foundation import IMDatabase +import IMessageCore import Logging import PlatformSDK @@ -10,18 +12,23 @@ private func traceMessageUpdates(_ message: @autoclosure () -> Logger.Message) { log.debug(message()) } -private func threadRefreshEvents(forUpdatedChats latest: UpdatedChatsQueryResult) throws -> [ServerEvent] { - guard !latest.updatedChats.isEmpty else { return [] } +private enum PendingMessageKind { + case reactionAdd + case normal(UpdatedMessageChange) +} - return latest.updatedChats.compactMap { chat in - guard let guid = chat.guid else { - log.error("updated chat didn't have a guid, not vending refresh event") - return nil - } - traceMessageUpdates("chat \(chat) had message updates, queueing a refresh") - let hashedThreadID = Hasher.thread.tokenizeRemembering(pii: guid) - return ServerEvent.refreshMessagesInThread(id: hashedThreadID) - } as [ServerEvent] +private struct PendingMessage { + let row: MappedMessageRow + let kind: PendingMessageKind +} + +private struct ThreadBatch { + let threadID: PlatformSDK.ThreadID + var upserts: [PlatformSDK.Message] = [] + var updates: [JSONObject] = [] + var deletes: [PlatformSDK.MessageID] = [] + var reactionUpsertsByMessageID = OrderedDictionary() + var reactionDeletesByMessageID = OrderedDictionary() } extension EventWatcher { @@ -33,31 +40,201 @@ extension EventWatcher { } func collectMessageUpdateEvents() throws -> [ServerEvent] { - let lastRowID = updatesCursor.lastRowID - let lastDateRead = updatesCursor.lastDateRead + let previousCursor = updatesCursor + let queryResult = try db.messages( + newerThanRowID: previousCursor.lastRowID, + orReadSince: previousCursor.lastDateRead, + orEditedSince: previousCursor.lastDateEdited + ) + traceMessageUpdates("updated messages query returned \(queryResult.updatedMessages.count) updated message(s)") - let queryResult = try db.chats(withMessagesNewerThanRowID: lastRowID, orReadSince: lastDateRead, orEditedSince: updatesCursor.lastDateEdited) - traceMessageUpdates("updated messages query returned \(queryResult.updatedChats.count) updated chat(s)") - guard !queryResult.updatedChats.isEmpty else { - traceMessageUpdates("no chats updated this time around") + let events = try messageUpdateEvents(for: queryResult) + let newCursor = MessageUpdatesCursor( + lastRowID: max(queryResult.latestMessageRowID ?? previousCursor.lastRowID, previousCursor.lastRowID), + lastDateRead: max(queryResult.latestMessageDateRead ?? previousCursor.lastDateRead, previousCursor.lastDateRead), + lastDateEdited: max(queryResult.latestDateEdited ?? previousCursor.lastDateEdited, previousCursor.lastDateEdited) + ) + traceMessageUpdates("done computing message state syncs, updating the messages updates cursor to: \(newCursor)") + updatesCursor = newCursor + return events + } + + private func messageUpdateEvents(for queryResult: UpdatedMessagesQueryResult) throws -> [ServerEvent] { + guard !queryResult.updatedMessages.isEmpty else { + traceMessageUpdates("no messages updated this time around") return [] } - guard let newLastRowID = queryResult.latestMessageRowID else { - log.error("didn't have new rowid cursor despite having updated chats? skipping updates") - return [] + + let msgRows = try db.mappedMessageRows(rowIDs: queryResult.updatedMessages.map(\.rowID)) + // `messageJoins` LEFT JOINs `chat_message_join`, so a message in multiple + // chats yields multiple rows with the same ROWID. Keep first. + let msgRowsByRowID = Dictionary(msgRows.map { ($0.rowID, $0) }, uniquingKeysWith: { first, _ in first }) + + var batchesByThreadID = [PlatformSDK.ThreadID: ThreadBatch]() + var pendingByThreadID = [PlatformSDK.ThreadID: OrderedDictionary]() + + for change in queryResult.updatedMessages { + guard let msgRow = msgRowsByRowID[change.rowID] else { + log.error("message update row \(change.rowID) couldn't be mapped, dropping") + continue + } + let threadID = msgRow.threadID ?? change.chatGUID + + if let associatedGUID = msgRow.associatedMessageGUID?.nonEmpty { + if let reaction = reaction(for: msgRow) { + let target = parseAssociatedMessageTarget(associatedGUID) + guard !target.messageID.isEmpty else { + log.error("message row \(msgRow.rowID) is a reaction but doesn't point at a message, dropping reaction state sync") + continue + } + + if change.isNew { + var batch = batchesByThreadID[threadID] ?? ThreadBatch(threadID: threadID) + switch reaction.action { + case .reacted: + if let messageReaction = mapMessageReaction(row: msgRow, reaction: reaction, currentUserID: currentUserID, accountID: accountID) { + batch.reactionUpsertsByMessageID[target.messageID, default: []].append(PlatformAPI.hashReaction(messageReaction)) + } else { + log.error("message row \(msgRow.rowID) is a reaction but couldn't be mapped, dropping reaction state sync") + } + pendingByThreadID[threadID, default: [:]][msgRow.rowID] = PendingMessage(row: msgRow, kind: .reactionAdd) + case .unreacted: + batch.reactionDeletesByMessageID[target.messageID, default: []].append( + PlatformAPI.hashedParticipantID(messageSenderID(for: msgRow, currentUserID: currentUserID)) + ) + if let replyToGUID = msgRow.replyToGUID { + batch.deletes.append(replyToGUID) + } + } + batchesByThreadID[threadID] = batch + } + + continue + } + + traceMessageUpdates("message row \(msgRow.rowID) is associated but not a reaction; treating as a message state sync") + } + + pendingByThreadID[threadID, default: [:]][msgRow.rowID] = PendingMessage(row: msgRow, kind: .normal(change)) + } + + let allPendingRows = pendingByThreadID.values.flatMap { $0.values.map(\.row) } + let mappedMessagesByRowID = try mapMessagesByRowID(allPendingRows) + + for (threadID, pendings) in pendingByThreadID { + var batch = batchesByThreadID[threadID] ?? ThreadBatch(threadID: threadID) + for pending in pendings.values { + let mappedMessages = mappedMessagesByRowID[pending.row.rowID] ?? [] + switch pending.kind { + case .reactionAdd: + batch.upserts.append(contentsOf: mappedMessages) + case .normal(let change): + if change.isNew { + batch.upserts.append(contentsOf: mappedMessages) + } + if let kind = MessageUpdateKind(change) { + batch.updates.append(contentsOf: mappedMessages.compactMap { kind.patch(for: $0) }) + } + } + } + batchesByThreadID[threadID] = batch } - defer { - let newCursor = MessageUpdatesCursor( - lastRowID: newLastRowID, - // Inherit the `lastDateRead` if it hasn't changed. - lastDateRead: queryResult.latestMessageDateRead ?? updatesCursor.lastDateRead, - lastDateEdited: queryResult.latestDateEdited ?? updatesCursor.lastDateEdited - ) - traceMessageUpdates("done computing refreshes, updating the messages updates cursor to: \(newCursor)") - updatesCursor = newCursor + for threadID in batchesByThreadID.keys { + guard var batch = batchesByThreadID[threadID], batch.updates.count > 1 else { continue } + batch.updates = deduplicatedUpdatePatches(batch.updates) + batchesByThreadID[threadID] = batch } - return try threadRefreshEvents(forUpdatedChats: queryResult) + return stateSyncEvents(batches: batchesByThreadID.values) + } + + enum MessageUpdateKind { + case edited, read + + init?(_ change: UpdatedMessageChange) { + // Edits dominate read receipts: a same-tick edit+read becomes a full-message patch. + if change.wasEdited { self = .edited } + else if change.wasRead { self = .read } + else { return nil } + } + + func patch(for message: PlatformSDK.Message) -> JSONObject? { + switch self { + case .edited: + return message.jsonObject + case .read: + var patch = compactDictionary([ + "seen": message.seen?.jsonValue, + "behavior": message.behavior?.rawValue, + "isDelivered": message.isDelivered, + "isErrored": message.isErrored, + ]) + guard !patch.isEmpty else { return nil } + patch["id"] = message.id + return patch + } + } + } + + private func mapMessagesByRowID(_ msgRows: [MappedMessageRow]) throws -> [Int: [PlatformSDK.Message]] { + try PlatformAPI.mapAndHashMessagesByRowID( + db: db, + msgRows: msgRows, + threadID: "", + currentUserID: currentUserID, + accountID: accountID + ) + } + + private func stateSyncEvents(batches: Dictionary.Values) -> [ServerEvent] { + var events = [ServerEvent]() + // Per-thread emit order keeps creates before updates and deletes after + // both; reaction events are scoped to their target message via + // objectIDs.messageID. + for batch in batches { + guard !batch.upserts.isEmpty || + !batch.updates.isEmpty || + !batch.deletes.isEmpty || + !batch.reactionUpsertsByMessageID.isEmpty || + !batch.reactionDeletesByMessageID.isEmpty else { continue } + let hashedThreadID = Hasher.thread.tokenizeRemembering(pii: batch.threadID) + for (messageID, reactions) in batch.reactionUpsertsByMessageID where !reactions.isEmpty { + events.append(.upsertMessageReactions(threadID: hashedThreadID, messageID: messageID, reactions: reactions)) + } + if !batch.upserts.isEmpty { + events.append(.upsertMessages(threadID: hashedThreadID, messages: batch.upserts)) + } + if !batch.updates.isEmpty { + events.append(.updateMessages(threadID: hashedThreadID, patches: batch.updates)) + } + for (messageID, ids) in batch.reactionDeletesByMessageID where !ids.isEmpty { + events.append(.deleteMessageReactions(threadID: hashedThreadID, messageID: messageID, ids: ids)) + } + if !batch.deletes.isEmpty { + events.append(.deleteMessages(threadID: hashedThreadID, ids: batch.deletes)) + } + } + return events + } + + private func deduplicatedUpdatePatches(_ patches: [JSONObject]) -> [JSONObject] { + guard patches.count > 1 else { return patches } + // OrderedDictionary preserves first-seen patch order so identical inputs produce + // identical event sequences; plain `Dictionary` value order is undefined. + var patchesByID = OrderedDictionary() + for patch in patches { + guard let id = patch["id"] as? String else { continue } + patchesByID[id, default: [:]].merge(patch) { _, new in new } + } + return Array(patchesByID.values) + } + + private func reaction(for msgRow: MappedMessageRow) -> AssociatedReaction? { + guard let associatedMessageType = associatedMessageTypes[msgRow.associatedMessageType], + case let .reaction(reaction) = associatedMessageType else { + return nil + } + return reaction } } diff --git a/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher.swift b/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher.swift index dc45a874..a0e72013 100644 --- a/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher.swift +++ b/src/IMessage/Sources/IMessage/EventWatcher/EventWatcher.swift @@ -19,15 +19,18 @@ final class EventWatcher { var db: IMDatabase /// Tracks the last known state of every chat. - var chatStates = [ChatRef: TimestampedChatState]() + var chatStates = [String: TimestampedChatState]() var updatesCursor: MessageUpdatesCursor + let currentUserID: String + let accountID: String private var sender: PlatformAPI.EventCallback private let reportErrorMessage: PlatformAPI.ReportErrorMessage? init( serverEventSender sender: @escaping PlatformAPI.EventCallback, initialUpdatesCursor: MessageUpdatesCursor, + accountID: String, reportErrorMessage: PlatformAPI.ReportErrorMessage? = nil ) throws { self.db = try IMDatabase() @@ -37,6 +40,8 @@ final class EventWatcher { } self.sender = sender self.updatesCursor = initialUpdatesCursor + self.currentUserID = try PlatformSDK.CurrentUser.fetch(from: db).id + self.accountID = accountID self.reportErrorMessage = reportErrorMessage } diff --git a/src/IMessage/Sources/IMessage/EventWatcher/EventWatcherLifecycle.swift b/src/IMessage/Sources/IMessage/EventWatcher/EventWatcherLifecycle.swift index 31b39ca5..20d8e40a 100644 --- a/src/IMessage/Sources/IMessage/EventWatcher/EventWatcherLifecycle.swift +++ b/src/IMessage/Sources/IMessage/EventWatcher/EventWatcherLifecycle.swift @@ -8,10 +8,15 @@ private let eventWatchingLog = Logger(imessageLabel: "event-watcher-lifecycle") final class EventWatcherLifecycle { static let shared = EventWatcherLifecycle() + private struct Subscription { + var onEvent: PlatformAPI.EventCallback + var reportErrorMessage: PlatformAPI.ReportErrorMessage? + var accountID: String + } + private struct State { - var onEvent: PlatformAPI.EventCallback? + var subscription: Subscription? var watchingTask: Task? - var reportErrorMessage: PlatformAPI.ReportErrorMessage? } private let state = Protected(State()) @@ -22,10 +27,17 @@ final class EventWatcherLifecycle { state.withLock { $0.watchingTask != nil } } - func subscribeToEvents(_ onEvent: @escaping PlatformAPI.EventCallback, reportErrorMessage: PlatformAPI.ReportErrorMessage? = nil) { + func subscribeToEvents( + _ onEvent: @escaping PlatformAPI.EventCallback, + accountID: String, + reportErrorMessage: PlatformAPI.ReportErrorMessage? = nil + ) { state.withLock { state in - state.onEvent = onEvent - state.reportErrorMessage = reportErrorMessage + state.subscription = Subscription( + onEvent: onEvent, + reportErrorMessage: reportErrorMessage, + accountID: accountID + ) } } @@ -34,8 +46,7 @@ final class EventWatcherLifecycle { let watchingTask = state.watchingTask state.watchingTask = nil if clearEventCallback { - state.onEvent = nil - state.reportErrorMessage = nil + state.subscription = nil } return watchingTask } @@ -51,22 +62,24 @@ final class EventWatcherLifecycle { } } - func startEventWatchingFromCurrentState(lastRowID: Int, lastDateRead: Date) throws { - guard let onEvent = state.withLock({ $0.onEvent }) else { + func startEventWatchingFromCurrentState(lastRowID: Int, lastDateRead: Date, lastDateEdited: Date) throws { + guard let subscription = state.withLock({ $0.subscription }) else { throw ErrorMessage("subscribeToEvents must be called before startEventWatchingFromCurrentState") } try startWatching( - onEvent: onEvent, + subscription: subscription, lastRowID: lastRowID, lastDateRead: lastDateRead, + lastDateEdited: lastDateEdited, source: "current state" ) } - func startWatching( - onEvent: @escaping PlatformAPI.EventCallback, + private func startWatching( + subscription: Subscription, lastRowID: Int, lastDateRead: Date, + lastDateEdited: Date, source: String ) throws { let existingTask = state.withLock { state in @@ -79,19 +92,18 @@ final class EventWatcherLifecycle { existingTask.cancel() } - eventWatchingLog.debug("starting event watcher from \(source) (last row id: \(lastRowID), last date read: \(lastDateRead))") - - let reportErrorMessage = state.withLock { $0.reportErrorMessage } + eventWatchingLog.debug("starting event watcher from \(source) (last row id: \(lastRowID), last date read: \(lastDateRead), last date edited: \(lastDateEdited))") let eventWatcher = try EventWatcher( serverEventSender: { events in #if DEBUG eventWatchingLog.debug("handing over \(events.count) value(s) to the event callback") #endif - try await onEvent(events) + try await subscription.onEvent(events) }, - initialUpdatesCursor: EventWatcher.MessageUpdatesCursor(lastRowID: lastRowID, lastDateRead: lastDateRead, lastDateEdited: Date()), - reportErrorMessage: reportErrorMessage + initialUpdatesCursor: EventWatcher.MessageUpdatesCursor(lastRowID: lastRowID, lastDateRead: lastDateRead, lastDateEdited: lastDateEdited), + accountID: subscription.accountID, + reportErrorMessage: subscription.reportErrorMessage ) let watchingTask = Task { @@ -100,7 +112,7 @@ final class EventWatcherLifecycle { try await eventWatcher.watchForever() } catch { eventWatchingLog.error("event watcher died: \(String(reflecting: error))") - try? reportErrorMessage?("imsg event watcher died: \(String(reflecting: error))") + try? subscription.reportErrorMessage?("imsg event watcher died: \(String(reflecting: error))") } } diff --git a/src/IMessage/Sources/IMessage/Hashing/PlatformAPI+Hashing.swift b/src/IMessage/Sources/IMessage/Hashing/PlatformAPI+Hashing.swift index 9201f273..1ac1f28d 100644 --- a/src/IMessage/Sources/IMessage/Hashing/PlatformAPI+Hashing.swift +++ b/src/IMessage/Sources/IMessage/Hashing/PlatformAPI+Hashing.swift @@ -37,15 +37,35 @@ extension PlatformAPI { currentUserID: String, accountID: String ) throws -> [PlatformSDK.Message] { + let messagesByRowID = try mapAndHashMessagesByRowID( + msgRows: msgRows, + attachmentRows: attachmentRows, + reactionRows: reactionRows, + currentUserID: currentUserID, + accountID: accountID + ) + return msgRows.flatMap { msgRow in + messagesByRowID[msgRow.rowID] ?? [] + } + } + + nonisolated static func mapAndHashMessagesByRowID( + msgRows: [MappedMessageRow], + attachmentRows: [MappedAttachmentRow], + reactionRows: [MappedReactionMessageRow], + currentUserID: String, + accountID: String + ) throws -> [Int: [PlatformSDK.Message]] { guard !msgRows.isEmpty else { - return [] + return [:] } let attachmentRowsByMessageID = Dictionary(grouping: attachmentRows, by: \.msgRowID) let reactionRowsByMessageGUID = Dictionary(grouping: reactionRows, by: { reactionMessageGUID($0.associatedMessageGUID) }) - return try msgRows.flatMap { msgRow -> [PlatformSDK.Message] in - try mapAndHashMessage( + var messagesByRowID = [Int: [PlatformSDK.Message]]() + for msgRow in msgRows { + messagesByRowID[msgRow.rowID] = try mapAndHashMessage( msgRow: msgRow, attachmentRows: attachmentRowsByMessageID[msgRow.rowID] ?? [], reactionRows: reactionRowsByMessageGUID[msgRow.guid] ?? [], @@ -53,6 +73,7 @@ extension PlatformAPI { accountID: accountID ) } + return messagesByRowID } nonisolated static func mapAndHashMessage( @@ -77,19 +98,21 @@ extension PlatformAPI { copyMessage( message, senderID: Hasher.participant.tokenizeRemembering(pii: message.senderID), - reactions: message.reactions?.map { reaction in - PlatformSDK.MessageReaction( - id: Hasher.participant.tokenizeRemembering(pii: reaction.id), - reactionKey: reaction.reactionKey, - imgURL: reaction.imgURL, - participantID: Hasher.participant.tokenizeRemembering(pii: reaction.participantID), - emoji: reaction.emoji - ) - }, + reactions: message.reactions?.map(hashReaction), threadID: message.threadID.map { Hasher.thread.tokenizeRemembering(pii: $0) } ) } + nonisolated static func hashReaction(_ reaction: PlatformSDK.MessageReaction) -> PlatformSDK.MessageReaction { + PlatformSDK.MessageReaction( + id: Hasher.participant.tokenizeRemembering(pii: reaction.id), + reactionKey: reaction.reactionKey, + imgURL: reaction.imgURL, + participantID: Hasher.participant.tokenizeRemembering(pii: reaction.participantID), + emoji: reaction.emoji + ) + } + nonisolated static func copyMessage( _ message: PlatformSDK.Message, senderID: PlatformSDK.UserID? = nil, diff --git a/src/IMessage/Sources/IMessage/Mappers/MessageMapper+Associated.swift b/src/IMessage/Sources/IMessage/Mappers/MessageMapper+Associated.swift index 185c79a0..4c556695 100644 --- a/src/IMessage/Sources/IMessage/Mappers/MessageMapper+Associated.swift +++ b/src/IMessage/Sources/IMessage/Mappers/MessageMapper+Associated.swift @@ -23,24 +23,19 @@ extension Mapper { ) -> MessageDraft? { let firstTextPart = messages.first { $0.text != nil } var message = firstTextPart ?? partialMessage - let guidRange = NSRange(associatedGUID.startIndex ..< associatedGUID.endIndex, in: associatedGUID) - let linkedMessageID = assocMsgGUIDPrefixRegex.stringByReplacingMatches( - in: associatedGUID, - range: guidRange, - withTemplate: "" - ) + let linkedMessageID = parseAssociatedMessageTarget(associatedGUID).messageID message.linkedMessageID = linkedMessageID - guard let assocMsgType = associatedMessageTypes[msgRow.associatedMessageType] else { + guard let associatedMessageType = associatedMessageTypes[msgRow.associatedMessageType] else { return nil } - switch assocMsgType { - case "sticker": + switch associatedMessageType { + case .sticker: if !messages.isEmpty { messages[0].linkedMessageID = linkedMessageID } return nil - case "heading": + case .heading: if var text = message.text { let other = msgRow.participantID ?? "" let isSender = message.isSender == true @@ -53,9 +48,9 @@ extension Mapper { } message.parseTemplate = true return message - default: + case let .reaction(reaction): return mapReactionAction( - assocMsgType: assocMsgType, + reaction: reaction, message: message, summaryInfo: summaryInfo, isSMS: isSMS @@ -72,20 +67,15 @@ extension Mapper { return reaction.associatedMessageGUID.hasPrefix("p:\(filterIndex)/") } for reaction in filteredRows { - guard let assocMsgType = associatedMessageTypes[reaction.associatedMessageType], - let parts = reactionParts(assocMsgType), - assocMsgType != "sticker" else { + guard let associatedMessageType = associatedMessageTypes[reaction.associatedMessageType], + case let .reaction(parts) = associatedMessageType else { continue } - let participantID = senderID(for: reaction) - if parts.actionType == "reacted" { - reactions.append(PlatformSDK.MessageReaction( - id: participantID, - reactionKey: parts.actionKey == "emoji" ? (reaction.associatedMessageEmoji ?? "") : parts.actionKey, - imgURL: parts.actionKey == "sticker" ? reactionStickerAssetURL(rowID: reaction.rowID) : nil, - participantID: participantID - )) - } else if parts.actionType == "unreacted", let index = reactions.firstIndex(where: { $0.id == participantID }) { + if parts.action == .reacted { + if let messageReaction = mapMessageReaction(row: reaction, reaction: parts, currentUserID: currentUserID, accountID: accountID) { + reactions.append(messageReaction) + } + } else if parts.action == .unreacted, let index = reactions.firstIndex(where: { $0.id == messageSenderID(for: reaction, currentUserID: currentUserID) }) { reactions.remove(at: index) } } @@ -102,67 +92,92 @@ extension Mapper { } func senderID() -> String { - senderID(for: msgRow) + messageSenderID(for: msgRow, currentUserID: currentUserID) } func reactionStickerAssetURL(rowID: Int) -> String { - "asset://\(accountID)/reaction-sticker/\(rowID).heic" + reactionStickerAssetURLString(accountID: accountID, rowID: rowID) } private func mapReactionAction( - assocMsgType: String, + reaction: AssociatedReaction, message inputMessage: MessageDraft, summaryInfo: JSONObject, isSMS: Bool ) -> MessageDraft { var message = inputMessage - guard let parts = reactionParts(assocMsgType) else { - return message - } - guard parts.actionType == "reacted" || parts.actionType == "unreacted" else { - return message - } message.isAction = !isSMS let action = PlatformSDK.PartialMessageReactionAction( messageID: message.linkedMessageID, - reactionKey: parts.actionKey == "emoji" ? msgRow.associatedMessageEmoji : parts.actionKey, - imgURL: assocMsgType == "reacted_sticker" ? reactionStickerAssetURL(rowID: msgRow.rowID) : nil, + reactionKey: reaction.platformReactionKey(emoji: msgRow.associatedMessageEmoji), + imgURL: reaction.includesStickerAssetInAction ? reactionStickerAssetURL(rowID: msgRow.rowID) : nil, participantID: message.senderID ) - message.action = parts.actionType == "reacted" + message.action = reaction.action == .reacted ? .messageReactionCreated(action) : .messageReactionDeleted(action) - if parts.actionKey == "emoji" || parts.actionKey == "sticker" || supportedReactionKeys.contains(parts.actionKey) { - message.parseTemplate = true - let actor = msgRow.isFromMe == 1 ? "You" : "{{sender}}" - let target = summaryInfo.string("ams").flatMap { $0.isEmpty ? nil : $0 }.map { "\"\($0)\"" } ?? "a message" - message.text = "\(actor) \(reactionVerbMap[assocMsgType] ?? "") \(target)" - message.isHidden = true - } + message.parseTemplate = true + let actor = msgRow.isFromMe == 1 ? "You" : "{{sender}}" + let target = summaryInfo.string("ams").flatMap { $0.isEmpty ? nil : $0 }.map { "\"\($0)\"" } ?? "a message" + message.text = "\(actor) \(reaction.verb) \(target)" + message.isHidden = true return message } - private func reactionParts(_ assocMsgType: String) -> (actionType: String, actionKey: String)? { - let pieces = assocMsgType.components(separatedBy: "_") - guard pieces.count == 2 else { - return nil - } - return (pieces[0], pieces[1]) - } - - private func senderID(for row: any RowWithSenderFields) -> String { - if row.isFromMe == 1 || ((row.participantID ?? "").isEmpty && row.handleID == 0) { - return currentUserID - } - return row.participantID ?? "" - } } -private protocol RowWithSenderFields { +protocol RowWithSenderFields { var isFromMe: Int { get } var handleID: Int? { get } var participantID: String? { get } } +protocol MessageReactionRowFields: RowWithSenderFields { + var rowID: Int { get } + var associatedMessageType: Int { get } + var associatedMessageEmoji: String? { get } +} + extension MappedMessageRow: RowWithSenderFields {} extension MappedReactionMessageRow: RowWithSenderFields {} +extension MappedMessageRow: MessageReactionRowFields {} +extension MappedReactionMessageRow: MessageReactionRowFields {} + +func messageSenderID(for row: any RowWithSenderFields, currentUserID: String) -> String { + if row.isFromMe == 1 || ((row.participantID ?? "").isEmpty && row.handleID == 0) { + return currentUserID + } + return row.participantID ?? "" +} + +func reactionStickerAssetURLString(accountID: String, rowID: Int) -> String { + "asset://\(accountID)/reaction-sticker/\(rowID).heic" +} + +func mapMessageReaction( + row: any MessageReactionRowFields, + currentUserID: String, + accountID: String +) -> PlatformSDK.MessageReaction? { + guard let associatedMessageType = associatedMessageTypes[row.associatedMessageType], + case let .reaction(reaction) = associatedMessageType else { + return nil + } + return mapMessageReaction(row: row, reaction: reaction, currentUserID: currentUserID, accountID: accountID) +} + +func mapMessageReaction( + row: any MessageReactionRowFields, + reaction: AssociatedReaction, + currentUserID: String, + accountID: String +) -> PlatformSDK.MessageReaction? { + let reactionKey = reaction.platformReactionKey(emoji: row.associatedMessageEmoji) ?? "" + let participantID = messageSenderID(for: row, currentUserID: currentUserID) + return PlatformSDK.MessageReaction( + id: participantID, + reactionKey: reactionKey, + imgURL: reaction.includesStickerAssetInAction ? reactionStickerAssetURLString(accountID: accountID, rowID: row.rowID) : nil, + participantID: participantID + ) +} diff --git a/src/IMessage/Sources/IMessage/Mappers/MessageMapperTypes.swift b/src/IMessage/Sources/IMessage/Mappers/MessageMapperTypes.swift index 0b8531d6..06a7bde6 100644 --- a/src/IMessage/Sources/IMessage/Mappers/MessageMapperTypes.swift +++ b/src/IMessage/Sources/IMessage/Mappers/MessageMapperTypes.swift @@ -8,6 +8,96 @@ let uuidStart = 11 let uuidLength = 36 let coreFoundationReferenceDateMilliseconds: Int64 = 978_307_200_000 +struct AssociatedMessageTarget: Hashable { + let part: String? + let messageGUID: String + + var messageID: PlatformSDK.MessageID { + if let part, part != "0" { + return "\(messageGUID)_\(part)" + } + return messageGUID + } +} + +func parseAssociatedMessageTarget(_ associatedMessageGUID: String) -> AssociatedMessageTarget { + let range = NSRange(associatedMessageGUID.startIndex ..< associatedMessageGUID.endIndex, in: associatedMessageGUID) + guard let match = assocMsgGUIDPrefixRegex.firstMatch(in: associatedMessageGUID, range: range), + let upper = Range(match.range, in: associatedMessageGUID)?.upperBound else { + return AssociatedMessageTarget(part: nil, messageGUID: associatedMessageGUID) + } + + let part = Range(match.range(at: 1), in: associatedMessageGUID).map { String(associatedMessageGUID[$0]) } + let rawMessageGUID = String(associatedMessageGUID[upper...]) + let messageGUID = rawMessageGUID.hasPrefix("bp:") ? String(rawMessageGUID.dropFirst(3)) : rawMessageGUID + return AssociatedMessageTarget(part: part, messageGUID: messageGUID) +} + +enum ReactionAction { + case reacted + case unreacted +} + +enum AssociatedReactionKey: String { + case heart + case like + case dislike + case laugh + case emphasize + case question + case emoji + case sticker +} + +struct AssociatedReaction { + let action: ReactionAction + let key: AssociatedReactionKey + + func platformReactionKey(emoji: String?) -> String? { + switch key { + case .emoji: + return emoji + default: + return key.rawValue + } + } + + var isSticker: Bool { + key == .sticker + } + + var includesStickerAssetInAction: Bool { + action == .reacted && isSticker + } + + var verb: String { + switch (action, key) { + case (.reacted, .heart): return "loved" + case (.reacted, .like): return "liked" + case (.reacted, .dislike): return "disliked" + case (.reacted, .laugh): return "laughed at" + case (.reacted, .emphasize): return "emphasized" + case (.reacted, .question): return "questioned" + case (.reacted, .emoji): return "reacted to" + case (.reacted, .sticker): return "reacted with a sticker to" + case (.unreacted, .heart): return "removed a heart from" + case (.unreacted, .like): return "removed a like from" + case (.unreacted, .dislike): return "removed a dislike from" + case (.unreacted, .laugh): return "removed a laugh from" + case (.unreacted, .emphasize): return "removed an exclamation from" + case (.unreacted, .question): return "removed a question mark from" + case (.unreacted, .emoji): return "unreacted from" + case (.unreacted, .sticker): return "removed a sticker from" + } + } +} + +enum AssociatedMessageType { + case heading + case sticker + case reaction(AssociatedReaction) +} + enum MessagePart { case text(index: Int, end: Int, text: String, attributes: PlatformSDK.TextAttributes?) case attachment(index: Int, end: Int, attachmentID: String) @@ -67,48 +157,31 @@ let videoExtensions: Set = [ "svi", "vob", "webm", "wmv", "yuv", ] -let associatedMessageTypes: [Int: String] = [ - 3: "heading", - 1000: "sticker", - 2000: "reacted_heart", - 2001: "reacted_like", - 2002: "reacted_dislike", - 2003: "reacted_laugh", - 2004: "reacted_emphasize", - 2005: "reacted_question", - 2006: "reacted_emoji", - 2007: "reacted_sticker", - 3000: "unreacted_heart", - 3001: "unreacted_like", - 3002: "unreacted_dislike", - 3003: "unreacted_laugh", - 3004: "unreacted_emphasize", - 3005: "unreacted_question", - 3006: "unreacted_emoji", - 3007: "unreacted_sticker", -] +private func associatedReaction(_ action: ReactionAction, _ key: AssociatedReactionKey) -> AssociatedMessageType { + .reaction(AssociatedReaction(action: action, key: key)) +} -let reactionVerbMap = [ - "reacted_heart": "loved", - "reacted_like": "liked", - "reacted_dislike": "disliked", - "reacted_laugh": "laughed at", - "reacted_emphasize": "emphasized", - "reacted_question": "questioned", - "reacted_emoji": "reacted to", - "reacted_sticker": "reacted with a sticker to", - "unreacted_heart": "removed a heart from", - "unreacted_like": "removed a like from", - "unreacted_dislike": "removed a dislike from", - "unreacted_laugh": "removed a laugh from", - "unreacted_emphasize": "removed an exclamation from", - "unreacted_question": "removed a question mark from", - "unreacted_emoji": "unreacted from", - "unreacted_sticker": "removed a sticker from", +let associatedMessageTypes: [Int: AssociatedMessageType] = [ + 3: .heading, + 1000: .sticker, + 2000: associatedReaction(.reacted, .heart), + 2001: associatedReaction(.reacted, .like), + 2002: associatedReaction(.reacted, .dislike), + 2003: associatedReaction(.reacted, .laugh), + 2004: associatedReaction(.reacted, .emphasize), + 2005: associatedReaction(.reacted, .question), + 2006: associatedReaction(.reacted, .emoji), + 2007: associatedReaction(.reacted, .sticker), + 3000: associatedReaction(.unreacted, .heart), + 3001: associatedReaction(.unreacted, .like), + 3002: associatedReaction(.unreacted, .dislike), + 3003: associatedReaction(.unreacted, .laugh), + 3004: associatedReaction(.unreacted, .emphasize), + 3005: associatedReaction(.unreacted, .question), + 3006: associatedReaction(.unreacted, .emoji), + 3007: associatedReaction(.unreacted, .sticker), ] -let supportedReactionKeys: Set = ["heart", "like", "dislike", "laugh", "emphasize", "question"] - let expressiveMessages = [ "com.apple.messages.effect.CKEchoEffect": "Echo screen", "com.apple.messages.effect.CKSpotlightEffect": "Spotlight screen", diff --git a/src/IMessage/Sources/IMessage/PlatformAPI.swift b/src/IMessage/Sources/IMessage/PlatformAPI.swift index c282c611..2aead388 100644 --- a/src/IMessage/Sources/IMessage/PlatformAPI.swift +++ b/src/IMessage/Sources/IMessage/PlatformAPI.swift @@ -117,19 +117,24 @@ public final class PlatformAPI { } public func subscribeToEvents(_ onEvent: @escaping EventCallback) { - EventWatcherLifecycle.shared.subscribeToEvents(onEvent, reportErrorMessage: errorMessageReporter) + EventWatcherLifecycle.shared.subscribeToEvents( + onEvent, + accountID: accountID, + reportErrorMessage: errorMessageReporter + ) } public func startEventWatchingFromCurrentState() async throws { let database = database - let (lastRowID, lastDateRead) = try await Task.detached(priority: .userInitiated) { + let cursorSnapshot = try await Task.detached(priority: .userInitiated) { try database.withDatabase { db in - (try db.lastMessageRowID(), try db.maxMessageDateRead()) + try db.messageUpdateCursorSnapshot() } }.value try EventWatcherLifecycle.shared.startEventWatchingFromCurrentState( - lastRowID: lastRowID, - lastDateRead: lastDateRead + lastRowID: cursorSnapshot.lastRowID, + lastDateRead: cursorSnapshot.lastDateRead, + lastDateEdited: cursorSnapshot.lastDateEdited ) } @@ -920,18 +925,17 @@ extension PlatformAPI { ) throws -> [String: [PlatformSDK.Message]] { let msgRows = Array(latestMessageRowsByChatGUID.values) let payloadRows = try messagePayloadRows(db: db, msgRows: msgRows, threadID: "") - let attachmentRowsByMessageID = Dictionary(grouping: payloadRows.attachmentRows, by: \.msgRowID) - let reactionRowsByMessageGUID = Dictionary(grouping: payloadRows.reactionRows, by: { reactionMessageGUID($0.associatedMessageGUID) }) + let messagesByRowID = try mapAndHashMessagesByRowID( + msgRows: msgRows, + attachmentRows: payloadRows.attachmentRows, + reactionRows: payloadRows.reactionRows, + currentUserID: currentUserID, + accountID: accountID + ) var latestMessagesByChatGUID = [String: [PlatformSDK.Message]]() for (guid, msgRow) in latestMessageRowsByChatGUID { - latestMessagesByChatGUID[guid] = try mapAndHashMessage( - msgRow: msgRow, - attachmentRows: attachmentRowsByMessageID[msgRow.rowID] ?? [], - reactionRows: reactionRowsByMessageGUID[msgRow.guid] ?? [], - currentUserID: currentUserID, - accountID: accountID - ) + latestMessagesByChatGUID[guid] = messagesByRowID[msgRow.rowID] ?? [] } return latestMessagesByChatGUID } @@ -957,6 +961,23 @@ extension PlatformAPI { return fileURL.path } + nonisolated static func mapAndHashMessagesByRowID( + db: IMDatabase, + msgRows: [MappedMessageRow], + threadID: String, + currentUserID: String, + accountID: String + ) throws -> [Int: [PlatformSDK.Message]] { + let payloadRows = try messagePayloadRows(db: db, msgRows: msgRows, threadID: threadID) + return try mapAndHashMessagesByRowID( + msgRows: msgRows, + attachmentRows: payloadRows.attachmentRows, + reactionRows: payloadRows.reactionRows, + currentUserID: currentUserID, + accountID: accountID + ) + } + nonisolated private static func messagePayloadRows( db: IMDatabase, msgRows: [MappedMessageRow], @@ -1019,12 +1040,7 @@ extension PlatformAPI { } nonisolated static func reactionMessageGUID(_ associatedMessageGUID: String) -> String { - let range = NSRange(associatedMessageGUID.startIndex ..< associatedMessageGUID.endIndex, in: associatedMessageGUID) - guard let match = assocMsgGUIDPrefixRegex.firstMatch(in: associatedMessageGUID, range: range), - let upper = Range(match.range, in: associatedMessageGUID)?.upperBound else { - return associatedMessageGUID - } - return String(associatedMessageGUID[upper...]) + parseAssociatedMessageTarget(associatedMessageGUID).messageGUID } nonisolated private static func getAsset(db database: PlatformAPIDatabase, pathHex: String, methodName: String) async throws -> AssetResult { diff --git a/src/IMessage/Sources/IMessageCore/Array+Chunks.swift b/src/IMessage/Sources/IMessageCore/Array+Chunks.swift new file mode 100644 index 00000000..18d72ea5 --- /dev/null +++ b/src/IMessage/Sources/IMessageCore/Array+Chunks.swift @@ -0,0 +1,8 @@ +public extension Array { + package func chunks(ofCount size: Int) -> [ArraySlice] { + guard size > 0 else { return [] } + return stride(from: startIndex, to: endIndex, by: size).map { start in + self[start ..< Swift.min(start + size, endIndex)] + } + } +} diff --git a/src/IMessage/Sources/IMessagePerfBench/README.md b/src/IMessage/Sources/IMessagePerfBench/README.md new file mode 100644 index 00000000..69a13ec0 --- /dev/null +++ b/src/IMessage/Sources/IMessagePerfBench/README.md @@ -0,0 +1,23 @@ +# IMessagePerfBench + +Backend-agnostic performance harness for the iMessage read paths. + +The benchmark intentionally calls public `IMDatabase` methods and public +`PlatformAPI` methods. It does not import GRDB, SQLiteData, or any other +storage implementation directly, so a branch can swap the internals behind +`IMDatabase` and keep using the same benchmark. + +Run through the repo wrapper for terminal tables: + +```sh +yarn perf:imessage +``` + +Useful variants: + +```sh +yarn perf:imessage --sql-only --iterations 20 +yarn perf:imessage --api-only --api-thread-samples 10 +yarn perf:imessage --with-parity --max-chats 5 --message-limit 20 +yarn perf:imessage --json +``` diff --git a/src/IMessage/Sources/IMessagePerfBench/main.swift b/src/IMessage/Sources/IMessagePerfBench/main.swift new file mode 100644 index 00000000..1fe8556e --- /dev/null +++ b/src/IMessage/Sources/IMessagePerfBench/main.swift @@ -0,0 +1,407 @@ +import ArgumentParser +import Foundation +import IMDatabase +import IMessage + +enum BenchmarkFormat: String, ExpressibleByArgument { + case json + case pretty + + init?(argument: String) { + self.init(rawValue: argument) + } +} + +struct BenchmarkMetadata: Encodable { + let messagesDir: String + let iterations: Int + let warmups: Int + let maxChats: Int + let messageLimit: Int + let apiThreadSamples: Int + let searchQuery: String + let createIndexes: Bool + let sqlIncluded: Bool + let apiIncluded: Bool +} + +struct BenchmarkSection: Encodable { + let skipped: Bool + let results: [BenchmarkResult] +} + +struct BenchmarkResult: Encodable { + let name: String + let resultCount: Int + let iterations: Int + let warmups: Int + let samplesMS: [Double] + let averageMS: Double + let p50MS: Double + let p95MS: Double + let minMS: Double + let maxMS: Double +} + +struct BenchmarkReport: Encodable { + let metadata: BenchmarkMetadata + let sql: BenchmarkSection + let api: BenchmarkSection +} + +struct SQLSample { + let threadRows: [MappedChatRow] + let messageChatGUIDs: [String] + let messageRows: [MappedMessageRow] + + var chatRowIDs: [Int] { + threadRows.map(\.rowID) + } + + var messageRowIDs: [Int] { + messageRows.map(\.rowID) + } + + var messageGUIDs: [String] { + messageRows.map(\.guid) + } + + var messageChatRowIDs: [Int] { + Array(Set(messageRows.compactMap(\.chatRowID))).sorted() + } +} + +enum BenchError: Error, CustomStringConvertible { + case noThreads + case invalidOption(String) + + var description: String { + switch self { + case .noThreads: + return "No iMessage threads were found in the selected Messages database." + case let .invalidOption(message): + return message + } + } +} + +@main +struct IMessagePerfBench: AsyncParsableCommand { + static let configuration = CommandConfiguration( + abstract: "Benchmark iMessage database and API read paths without depending on a specific SQL backend." + ) + + @Option(help: "Messages data directory. Defaults to ~/Library/Messages.") + var messagesDir: String = "~/Library/Messages" + + @Option(help: "Measured iterations per benchmark case.") + var iterations: Int = 7 + + @Option(help: "Warmup iterations per benchmark case.") + var warmups: Int = 2 + + @Option(help: "Maximum chats to sample for SQL benchmarks.") + var maxChats: Int = 10 + + @Option(help: "Maximum messages per sampled chat.") + var messageLimit: Int = 50 + + @Option(help: "Maximum threads to sample for PlatformAPI.getMessages.") + var apiThreadSamples: Int = 5 + + @Option(help: "Search text for searchMessages benchmarks.") + var searchQuery: String = "a" + + @Flag(help: "Ask IMDatabase to create its optional read indexes before benchmarking.") + var createIndexes = false + + @Flag(help: "Only run IMDatabase SQL hot path benchmarks.") + var sqlOnly = false + + @Flag(help: "Only run final PlatformAPI.getThreads/getMessages benchmarks.") + var apiOnly = false + + @Option(help: "Output format.") + var format: BenchmarkFormat = .json + + mutating func run() async throws { + try validateOptions() + + let messagesURL = expandTilde(in: messagesDir) + let includeSQL = !apiOnly + let includeAPI = !sqlOnly + + let sqlResults = includeSQL + ? try runSQLBenchmarks(messagesURL: messagesURL) + : [] + let apiResults = includeAPI + ? try await runAPIBenchmarks() + : [] + + let report = BenchmarkReport( + metadata: BenchmarkMetadata( + messagesDir: messagesURL.path, + iterations: iterations, + warmups: warmups, + maxChats: maxChats, + messageLimit: messageLimit, + apiThreadSamples: apiThreadSamples, + searchQuery: searchQuery, + createIndexes: createIndexes, + sqlIncluded: includeSQL, + apiIncluded: includeAPI + ), + sql: BenchmarkSection(skipped: !includeSQL, results: sqlResults), + api: BenchmarkSection(skipped: !includeAPI, results: apiResults) + ) + + switch format { + case .json: + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + FileHandle.standardOutput.write(try encoder.encode(report)) + FileHandle.standardOutput.write(Data("\n".utf8)) + case .pretty: + printPretty(report) + } + } + + private func validateOptions() throws { + if iterations <= 0 { + throw BenchError.invalidOption("--iterations must be greater than zero.") + } + if warmups < 0 { + throw BenchError.invalidOption("--warmups must be zero or greater.") + } + if maxChats <= 0 { + throw BenchError.invalidOption("--max-chats must be greater than zero.") + } + if messageLimit <= 0 { + throw BenchError.invalidOption("--message-limit must be greater than zero.") + } + if apiThreadSamples <= 0 { + throw BenchError.invalidOption("--api-thread-samples must be greater than zero.") + } + if sqlOnly && apiOnly { + throw BenchError.invalidOption("--sql-only and --api-only cannot both be set.") + } + } + + private func runSQLBenchmarks(messagesURL: URL) throws -> [BenchmarkResult] { + let db = try IMDatabase(messagesDataBaseURL: messagesURL, createIndexes: createIndexes) + let sample = try makeSQLSample(db: db) + var results: [BenchmarkResult] = [] + + results.append(try measure("mappedThreadRows") { + try db.mappedThreadRows(cursor: nil, direction: nil, limit: maxChats).count + }) + results.append(try measure("mappedLatestMessageRows") { + try db.mappedLatestMessageRows(chatRowIDs: sample.chatRowIDs).count + }) + results.append(try measure("mappedThreadParticipantRows") { + try db.mappedThreadParticipantRows(chatRowIDs: sample.chatRowIDs).values.reduce(0) { $0 + $1.count } + }) + results.append(try measure("mappedUnreadCounts") { + try db.mappedUnreadCounts(chatRowIDs: sample.chatRowIDs).count + }) + results.append(try measure("mappedMessageRows.page") { + var count = 0 + for chatGUID in sample.messageChatGUIDs { + count += try db.mappedMessageRows(in: chatGUID, cursor: nil, direction: nil, limit: messageLimit).count + } + return count + }) + results.append(try measure("mappedMessageRows.rowIDs") { + try db.mappedMessageRows(rowIDs: sample.messageRowIDs).count + }) + results.append(try measure("mappedMessageRows.guids") { + try db.mappedMessageRows(guids: sample.messageGUIDs).count + }) + results.append(try measure("mappedAttachmentRows") { + try db.mappedAttachmentRows(messageRowIDs: sample.messageRowIDs).count + }) + results.append(try measure("mappedReactionRows") { + try db.mappedReactionRows(messageGUIDs: sample.messageGUIDs, chatRowIDs: sample.messageChatRowIDs).count + }) + results.append(try measure("messageUpdateCursorSnapshot") { + try db.messageUpdateCursorSnapshot().lastRowID + }) + results.append(try measure("chatStates") { + try db.chatStates().count + }) + results.append(try measure("searchMessages") { + try db.searchMessages(query: searchQuery, limit: messageLimit).count + }) + + return results + } + + private func makeSQLSample(db: IMDatabase) throws -> SQLSample { + let threadRows = try db.mappedThreadRows(cursor: nil, direction: nil, limit: maxChats) + guard !threadRows.isEmpty else { + throw BenchError.noThreads + } + + var messageChatGUIDs: [String] = [] + var messageRows: [MappedMessageRow] = [] + for threadRow in threadRows { + let rows = try db.mappedMessageRows(in: threadRow.guid, cursor: nil, direction: nil, limit: messageLimit) + guard !rows.isEmpty else { continue } + messageChatGUIDs.append(threadRow.guid) + messageRows.append(contentsOf: rows) + } + + return SQLSample( + threadRows: threadRows, + messageChatGUIDs: messageChatGUIDs, + messageRows: messageRows + ) + } + + private func runAPIBenchmarks() async throws -> [BenchmarkResult] { + let api = try PlatformAPI(accountID: "perf-bench", enforceSingleton: false) + do { + let threadPage = try await api.getThreads(folderName: "normal", pagination: nil) + let threadIDs = Array(threadPage.items.prefix(apiThreadSamples).map(\.id)) + guard !threadIDs.isEmpty else { + throw BenchError.noThreads + } + + var results: [BenchmarkResult] = [] + results.append(try await measureAsync("PlatformAPI.getThreads.firstPage") { + try await api.getThreads(folderName: "normal", pagination: nil).items.count + }) + results.append(try await measureAsync("PlatformAPI.getMessages.sampleThreads") { + var count = 0 + for threadID in threadIDs { + count += try await api.getMessages(threadID: threadID, pagination: nil).items.count + } + return count + }) + try? await api.dispose() + return results + } catch { + try? await api.dispose() + throw error + } + } + + private func measure(_ name: String, operation: () throws -> Int) throws -> BenchmarkResult { + for _ in 0.. Int) async throws -> BenchmarkResult { + for _ in 0.. URL { + let expandedPath: String + if path == "~" { + expandedPath = NSHomeDirectory() + } else if path.hasPrefix("~/") { + expandedPath = NSHomeDirectory() + String(path.dropFirst()) + } else { + expandedPath = path + } + return URL(fileURLWithPath: expandedPath, isDirectory: true) +} + +private func milliseconds(fromNanoseconds nanoseconds: UInt64) -> Double { + Double(nanoseconds) / 1_000_000 +} + +private func average(_ samples: [Double]) -> Double { + guard !samples.isEmpty else { return 0 } + return samples.reduce(0, +) / Double(samples.count) +} + +private func percentile(_ samples: [Double], _ percentile: Double) -> Double { + guard !samples.isEmpty else { return 0 } + let sorted = samples.sorted() + let index = max(0, min(sorted.count - 1, Int(ceil(Double(sorted.count) * percentile)) - 1)) + return sorted[index] +} + +private func printPretty(_ report: BenchmarkReport) { + print("iMessage perf benchmark") + print("Messages dir: \(report.metadata.messagesDir)") + print("Iterations: \(report.metadata.iterations), warmups: \(report.metadata.warmups)") + print() + printSection("SQL hot paths", report.sql.results) + print() + printSection("Platform API", report.api.results) +} + +private func printSection(_ title: String, _ results: [BenchmarkResult]) { + guard !results.isEmpty else { + print("\(title): skipped") + return + } + + print(title) + print("\(pad("name", to: 40)) \(pad("rows", to: 8)) \(pad("avg ms", to: 10)) \(pad("p50 ms", to: 10)) \(pad("p95 ms", to: 10))") + for result in results { + let row = [ + pad(result.name, to: 40), + pad(String(result.resultCount), to: 8), + pad(String(format: "%.3f", result.averageMS), to: 10), + pad(String(format: "%.3f", result.p50MS), to: 10), + pad(String(format: "%.3f", result.p95MS), to: 10), + ].joined(separator: " ") + print(row) + } +} + +private func pad(_ value: String, to width: Int) -> String { + let trimmed = value.count > width ? String(value.prefix(width - 1)) + "*" : value + return trimmed + String(repeating: " ", count: max(0, width - trimmed.count)) +} diff --git a/src/IMessage/Sources/PlatformSDK/ServerEvent.swift b/src/IMessage/Sources/PlatformSDK/ServerEvent.swift index 3b875544..9384e2f2 100644 --- a/src/IMessage/Sources/PlatformSDK/ServerEvent.swift +++ b/src/IMessage/Sources/PlatformSDK/ServerEvent.swift @@ -4,6 +4,7 @@ extension PlatformSDK { public enum ServerEventType: String { case stateSync = "state_sync" case toast + @available(*, deprecated, message: "Use state_sync message events instead.") case threadMessagesRefresh = "thread_messages_refresh" case userActivity = "user_activity" case userPresenceUpdated = "user_presence_updated" @@ -28,6 +29,7 @@ public enum ServerEvent { /// Displays user-visible text in a dismissible notification. case toast(message: String, id: String?, timeoutMilliseconds: Int?) /// A server event with type `thread_messages_refresh`. + @available(*, deprecated, message: "Use state_sync message events instead.") case refreshMessagesInThread(id: PlatformSDK.ThreadID) /// A server event with type `state_sync` that is used to `update` a /// `thread`. @@ -35,6 +37,21 @@ public enum ServerEvent { /// A server event with type `state_sync` that is used to `delete` /// one or more threads. case deleteThreads(ids: [PlatformSDK.ThreadID]) + /// A server event with type `state_sync` that is used to `upsert` + /// messages in a thread. + case upsertMessages(threadID: PlatformSDK.ThreadID, messages: [PlatformSDK.Message]) + /// A server event with type `state_sync` that is used to `update` + /// messages in a thread. + case updateMessages(threadID: PlatformSDK.ThreadID, patches: [JSONObject]) + /// A server event with type `state_sync` that is used to `delete` + /// messages in a thread. + case deleteMessages(threadID: PlatformSDK.ThreadID, ids: [PlatformSDK.MessageID]) + /// A server event with type `state_sync` that is used to `upsert` + /// reactions for a message. + case upsertMessageReactions(threadID: PlatformSDK.ThreadID, messageID: PlatformSDK.MessageID, reactions: [PlatformSDK.MessageReaction]) + /// A server event with type `state_sync` that is used to `delete` + /// reactions for a message. + case deleteMessageReactions(threadID: PlatformSDK.ThreadID, messageID: PlatformSDK.MessageID, ids: [PlatformSDK.ID]) } extension ServerEvent { @@ -65,7 +82,7 @@ extension ServerEvent { ] case let .refreshMessagesInThread(id): return [ - "type": PlatformSDK.ServerEventType.threadMessagesRefresh.rawValue, + "type": "thread_messages_refresh", "threadID": id, ] case let .stateSyncThread(id, patch): @@ -77,7 +94,7 @@ extension ServerEvent { return [ "type": PlatformSDK.ServerEventType.stateSync.rawValue, - "objectIDs": ["threadID": NSNull(), "messageID": NSNull()], + "objectIDs": JSONObject(), "objectName": "thread", "mutationType": "update", "entries": [entry], @@ -85,14 +102,69 @@ extension ServerEvent { case let .deleteThreads(ids): return [ "type": PlatformSDK.ServerEventType.stateSync.rawValue, - "objectIDs": ["threadID": NSNull(), "messageID": NSNull()], + "objectIDs": JSONObject(), "objectName": "thread", "mutationType": "delete", "entries": ids, ] + case let .upsertMessages(threadID, messages): + return messageStateSyncJSON( + threadID: threadID, + mutationType: "upsert", + entries: messages.map(\.jsonObject) + ) + case let .updateMessages(threadID, patches): + return messageStateSyncJSON( + threadID: threadID, + mutationType: "update", + entries: patches + ) + case let .deleteMessages(threadID, ids): + return messageStateSyncJSON( + threadID: threadID, + mutationType: "delete", + entries: ids + ) + case let .upsertMessageReactions(threadID, messageID, reactions): + return messageReactionStateSyncJSON( + threadID: threadID, + messageID: messageID, + mutationType: "upsert", + entries: reactions.map(\.jsonObject) + ) + case let .deleteMessageReactions(threadID, messageID, ids): + return messageReactionStateSyncJSON( + threadID: threadID, + messageID: messageID, + mutationType: "delete", + entries: ids + ) } } + private func messageStateSyncJSON(threadID: PlatformSDK.ThreadID, mutationType: String, entries: Any) -> JSONObject { + [ + "type": PlatformSDK.ServerEventType.stateSync.rawValue, + "objectIDs": ["threadID": threadID], + "objectName": "message", + "mutationType": mutationType, + "entries": entries, + ] + } + + private func messageReactionStateSyncJSON(threadID: PlatformSDK.ThreadID, messageID: PlatformSDK.MessageID, mutationType: String, entries: Any) -> JSONObject { + [ + "type": PlatformSDK.ServerEventType.stateSync.rawValue, + "objectIDs": [ + "threadID": threadID, + "messageID": messageID, + ], + "objectName": "message_reaction", + "mutationType": mutationType, + "entries": entries, + ] + } + private func jsonObjectValue(_ value: Any) -> Any { switch value { case is String, is Bool, is Int, is Double, is Float, is NSNull: diff --git a/src/api.ts b/src/api.ts index 8e610a18..a4b0cb4a 100644 --- a/src/api.ts +++ b/src/api.ts @@ -10,7 +10,7 @@ import { shellExec } from './util' import imessage, { type NativeMacPermissionAuthStatus, type NativePlatformAPI } from './IMessage/lib' import { makeJSONPersistence, Persistence } from './persistence' import { appleDateToMillisSinceEpoch, makeAppleDate } from './time' -import { parseSwiftMessageAPIJSON } from './swift-json' +import { parseSwiftMessageAPIJSON, reviveSwiftMessageAPIValue } from './swift-json' imessage.isLoggingEnabled = texts.isLoggingEnabled @@ -107,11 +107,11 @@ export default class AppleiMessage implements PlatformAPI { subscribeToEvents = async (onEvent: OnServerEventCallback): Promise => { this.onEvent = (events: ServerEvent[]) => { const evs: ServerEvent[] = [] - events.forEach(ev => { - if (ev.type === ServerEventType.TOAST) { - texts.Sentry.captureMessage(`iMessage: ${ev.toast.text}`) + events.forEach(event => { + if (event.type === ServerEventType.TOAST) { + texts.Sentry.captureMessage(`iMessage: ${event.toast.text}`) } else { - evs.push(ev) + evs.push(reviveSwiftMessageAPIValue(event)) } }) onEvent(evs) diff --git a/src/swift-json.test.ts b/src/swift-json.test.ts new file mode 100644 index 00000000..21288737 --- /dev/null +++ b/src/swift-json.test.ts @@ -0,0 +1,62 @@ +import { parseSwiftMessageAPIJSON, reviveSwiftMessageAPIValue } from './swift-json' + +describe('swift-json', () => { + it('converts message date fields in parsed JSON', () => { + const parsed = parseSwiftMessageAPIJSON<{ + timestamp: Date + editedTimestamp: Date + seen: Date + sortKey: number + }>(JSON.stringify({ + timestamp: 1, + editedTimestamp: 2, + seen: 3, + sortKey: 4, + })) + + expect(parsed.timestamp).toEqual(new Date(1)) + expect(parsed.editedTimestamp).toEqual(new Date(2)) + expect(parsed.seen).toEqual(new Date(3)) + expect(parsed.sortKey).toBe(4) + }) + + it('converts event date fields in already-parsed Swift values', () => { + const revived = reviveSwiftMessageAPIValue({ + entries: [{ + id: 'message-id', + timestamp: 1, + editedTimestamp: 2, + seen: { + alice: 3, + bob: true, + }, + sortKey: 4, + }], + }) + + expect(revived.entries[0].timestamp).toEqual(new Date(1)) + expect(revived.entries[0].editedTimestamp).toEqual(new Date(2)) + expect(revived.entries[0].seen).toEqual({ + alice: new Date(3), + bob: true, + }) + expect(revived.entries[0].sortKey).toBe(4) + }) + + it('mutates already-parsed Swift values while reviving', () => { + const event = { + entries: [{ + timestamp: 1, + seen: { alice: 2 }, + }], + } + + const revived = reviveSwiftMessageAPIValue(event) + + expect(revived).toBe(event) + expect(event.entries[0].timestamp).toEqual(new Date(1)) + expect(event.entries[0].seen.alice).toEqual(new Date(2)) + expect(revived.entries[0].timestamp).toEqual(new Date(1)) + expect(revived.entries[0].seen.alice).toEqual(new Date(2)) + }) +}) diff --git a/src/swift-json.ts b/src/swift-json.ts index 5efbe2b4..898ef776 100644 --- a/src/swift-json.ts +++ b/src/swift-json.ts @@ -1,7 +1,57 @@ -const SWIFT_DATE_FIELDS = new Set(['timestamp', 'seen', 'editedTimestamp']) +const SWIFT_DATE_FIELDS = new Set([ + 'createdAt', + 'editedTimestamp', + 'lastActive', + 'mutedUntil', + 'seen', + 'timestamp', +]) -export const swiftMapperReviver = (key: string, value: unknown): unknown => - SWIFT_DATE_FIELDS.has(key) && typeof value === 'number' ? new Date(value) : value +const isMutableRecord = (value: unknown): value is Record => + !!value && typeof value === 'object' && !Array.isArray(value) && !(value instanceof Date) + +export const swiftMapperReviver = (key: string, value: unknown): unknown => { + if (SWIFT_DATE_FIELDS.has(key) && typeof value === 'number') return new Date(value) + if (key === 'seen' && isMutableRecord(value)) { + const seenByParticipantID = value + Object.entries(seenByParticipantID).forEach(([participantID, seenValue]) => { + if (typeof seenValue === 'number') seenByParticipantID[participantID] = new Date(seenValue) + }) + } + return value +} + +const reviveSwiftDateFields = (record: Record): void => { + SWIFT_DATE_FIELDS.forEach(field => { + if (field in record) record[field] = swiftMapperReviver(field, record[field]) + }) +} + +const reviveSwiftEventEntry = (entry: unknown): void => { + if (isMutableRecord(entry)) reviveSwiftDateFields(entry) +} + +export const reviveSwiftMessageAPIValue = (value: T): T => { + // Intentionally mutates already-parsed Swift bridge payloads in place. These + // values are transient event objects. Keep the work targeted to the event + // envelope and state-sync entries instead of walking attachments/extras. + if (Array.isArray(value)) { + value.forEach(reviveSwiftEventEntry) + return value + } + if (!isMutableRecord(value)) return swiftMapperReviver('', value) as T + + reviveSwiftDateFields(value) + + if (Array.isArray(value.entries)) { + value.entries.forEach(reviveSwiftEventEntry) + } + if (isMutableRecord(value.presence)) { + reviveSwiftDateFields(value.presence) + } + + return value +} export const parseSwiftMessageAPIJSON = (json: string): T => JSON.parse(json, swiftMapperReviver) as T diff --git a/todos.md b/todos.md index 33956ed8..862bae81 100644 --- a/todos.md +++ b/todos.md @@ -29,12 +29,6 @@ - [ ] one off command to print presence (dnd / dnd w notify) and typing status - [ ] tests -- instead of `thread_messages_refresh` - - [ ] new incoming messages should be state sync message upserts - - [ ] new added/removed reactions should be state sync message upserts/deletes (for the hidden reaction message) and a state sync message update (for the og message) - - [ ] messages edited should be state sync message updates - - [ ] messages getting read should be state sync message updates - ### Parity - [ ] add delete message for me command @@ -64,3 +58,8 @@ - [x] add undo send CLI command - [x] fix notify anyway on tahoe - [x] fix unmute thread on tahoe +- instead of `thread_messages_refresh` + - [x] new incoming messages should be state sync message upserts + - [x] new added/removed reactions should be state sync message upserts/deletes (for the hidden reaction message) and a state sync message update (for the og message) + - [x] messages edited should be state sync message updates + - [x] messages getting read should be state sync message updates