diff --git a/.github/ISSUE_TEMPLATE/00-bug_report_zh.yml b/.github/ISSUE_TEMPLATE/00-bug_report_zh.yml index c6aa3fc62..ec18a650b 100644 --- a/.github/ISSUE_TEMPLATE/00-bug_report_zh.yml +++ b/.github/ISSUE_TEMPLATE/00-bug_report_zh.yml @@ -82,3 +82,9 @@ body: label: 复现链接(可选) description: | 请提供能复现此问题的链接。 + - type: textarea + id: aigenerated + attributes: + label: AI生成内容(可选) + description: | + 如果此问题是由AI辅助您发现的,请提供全部聊天记录,包括使用的模型信息。 diff --git a/.github/ISSUE_TEMPLATE/01-bug_report_en.yml b/.github/ISSUE_TEMPLATE/01-bug_report_en.yml index d99968d90..0085d4de5 100644 --- a/.github/ISSUE_TEMPLATE/01-bug_report_en.yml +++ b/.github/ISSUE_TEMPLATE/01-bug_report_en.yml @@ -82,3 +82,9 @@ body: label: Reproduction Link (optional) description: | Please provide a link to a repo or page that can reproduce this issue. + - type: textarea + id: aigenerated + attributes: + label: AI Generated Content (optional) + description: | + If this issue was identified with the assistance of AI, please provide the complete chat log, including information about the model used. diff --git a/.github/ISSUE_TEMPLATE/02-feature_request_zh.yml b/.github/ISSUE_TEMPLATE/02-feature_request_zh.yml index 8339d947f..907240261 100644 --- a/.github/ISSUE_TEMPLATE/02-feature_request_zh.yml +++ b/.github/ISSUE_TEMPLATE/02-feature_request_zh.yml @@ -48,3 +48,9 @@ body: label: 附加信息 description: | 相关的任何其他上下文或截图,或者你觉得有帮助的信息 + - type: textarea + id: aigenerated + attributes: + label: AI生成内容(可选) + description: | + 如果此请求是由AI辅助您提交的,请提供全部聊天记录,包括使用的模型信息。 diff --git a/.github/ISSUE_TEMPLATE/03-feature_request_en.yml b/.github/ISSUE_TEMPLATE/03-feature_request_en.yml index 41c9990cf..393118592 100644 --- a/.github/ISSUE_TEMPLATE/03-feature_request_en.yml +++ b/.github/ISSUE_TEMPLATE/03-feature_request_en.yml @@ -48,3 +48,9 @@ body: label: Additional Information description: | Any other context or screenshots related to this feature request, or information you find helpful. + - type: textarea + id: aigenerated + attributes: + label: AI Generated Content (optional) + description: | + If this request was submitted with the assistance of an AI, please provide the complete chat log, including information about the model used. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index f1687eabf..58e122837 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,58 +1,114 @@ + +## Summary / 摘要 + -## Description / 描述 + - +- 列出用户可感知的行为变化。 +- 列出重要实现变化。 +- 如涉及配置、存储、API 或兼容性变化,请明确说明。 +--> -## Motivation and Context / 背景 +- [ ] This PR has breaking changes. + / 此 PR 包含破坏性变更。 +- [ ] This PR changes public API, config, storage format, or migration behavior. + / 此 PR 修改了公开 API、配置、存储格式或迁移行为。 +- [ ] This PR requires corresponding changes in related repositories. + / 此 PR 需要关联仓库同步修改。 - - +Related repository PRs / 关联仓库 PR: - - +- OpenList-Frontend: +- OpenList-Docs: -Closes #XXXX +## Related Issues / 关联 Issue - - + -Relates to #XXXX +## Testing / 测试 -## How Has This Been Tested? / 测试 + - - +- [ ] `go test ./...` +- [ ] Manual test / 手动测试: ## Checklist / 检查清单 - - - - - - -- [ ] I have read the [CONTRIBUTING](https://github.com/OpenListTeam/OpenList/blob/main/CONTRIBUTING.md) document. - 我已阅读 [CONTRIBUTING](https://github.com/OpenListTeam/OpenList/blob/main/CONTRIBUTING.md) 文档。 -- [ ] I have formatted my code with `go fmt` or [prettier](https://prettier.io/). - 我已使用 `go fmt` 或 [prettier](https://prettier.io/) 格式化提交的代码。 -- [ ] I have added appropriate labels to this PR (or mentioned needed labels in the description if lacking permissions). - 我已为此 PR 添加了适当的标签(如无权限或需要的标签不存在,请在描述中说明,管理员将后续处理)。 -- [ ] I have requested review from relevant code authors using the "Request review" feature when applicable. - 我已在适当情况下使用"Request review"功能请求相关代码作者进行审查。 -- [ ] I have updated the repository accordingly (If it’s needed). - 我已相应更新了相关仓库(若适用)。 - - [ ] [OpenList-Frontend](https://github.com/OpenListTeam/OpenList-Frontend) #XXXX - - [ ] [OpenList-Docs](https://github.com/OpenListTeam/OpenList-Docs) #XXXX +- [ ] I have read [CONTRIBUTING](https://github.com/OpenListTeam/OpenList/blob/main/CONTRIBUTING.md). + / 我已阅读 [CONTRIBUTING](https://github.com/OpenListTeam/OpenList/blob/main/CONTRIBUTING.md)。 +- [ ] I confirm this contribution follows the repository license, contribution policy, and code of conduct. + / 我确认此贡献符合仓库许可证、贡献规范和行为准则。 +- [ ] I have formatted the changed code with `gofmt`, `go fmt`, or `prettier` where applicable. + / 我已按适用情况使用 `gofmt`、`go fmt` 或 `prettier` 格式化变更代码。 +- [ ] I have requested review from relevant maintainers or code owners where applicable. + / 我已在适用情况下请求相关维护者或代码所有者审查。 + +## AI Disclosure / AI 使用声明 + + + +- [ ] This PR includes AI-assisted content. + / 此 PR 包含 AI 辅助内容。 + +Tools used / 使用工具: + +- [ ] ChatGPT +- [ ] Codex +- [ ] GitHub Copilot +- [ ] Claude +- [ ] Gemini +- [ ] Other (please specify) / 其他(请注明): + +Usage scope / 使用范围: + +- [ ] Code generation / 代码生成 +- [ ] Refactoring / 重构 +- [ ] Documentation / 文档 +- [ ] Tests / 测试 +- [ ] Translation / 翻译 +- [ ] Review assistance / 审查辅助 + +- [ ] I have reviewed and validated all AI-assisted content included in this PR. + / 我已审核并验证此 PR 中的所有 AI 辅助内容。 +- [ ] I have ensured that all AI-assisted commits include `Co-Authored-By` attribution. + / 我已确保所有 AI 辅助提交都包含 `Co-Authored-By` 归属信息。 +- [ ] I can reproduce all AI-assisted content included in this PR without any AI tools. + / 我可以在没有任何 AI 工具的情况下重现此 PR 中包含的所有 AI 辅助内容。 diff --git a/.github/workflows/beta_release.yml b/.github/workflows/beta_release.yml index d5817e6b4..9ea36c4e5 100644 --- a/.github/workflows/beta_release.yml +++ b/.github/workflows/beta_release.yml @@ -48,7 +48,7 @@ jobs: tag_name: beta - name: Upload assets to github artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: beta changelog path: ${{ github.workspace }}/CHANGELOG.md @@ -110,7 +110,7 @@ jobs: fetch-depth: 0 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: "1.25.0" @@ -137,6 +137,7 @@ jobs: github.com/OpenListTeam/OpenList/v4/internal/conf.GitCommit=$git_commit github.com/OpenListTeam/OpenList/v4/internal/conf.Version=$tag github.com/OpenListTeam/OpenList/v4/internal/conf.WebVersion=rolling + github.com/OpenListTeam/OpenList/v4/internal/conf.FrontendRepoDefault=${{ vars.FRONTEND_REPO || 'Ironboxplus/OpenList-Frontend' }} env: GOFLAGS: ${{ matrix.goflags }} @@ -182,7 +183,7 @@ jobs: echo "cleaned_target=$CLEANED_TARGET" >> $GITHUB_ENV - name: Upload assets to github artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: beta builds for ${{ env.cleaned_target }} path: ${{ github.workspace }}/build/compress/* diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a3a501ffa..fa1528ad7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,6 +9,9 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true +env: + FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true + jobs: build: strategy: @@ -31,7 +34,7 @@ jobs: id: short-sha - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: "1.25.0" @@ -55,7 +58,8 @@ jobs: github.com/OpenListTeam/OpenList/v4/internal/conf.GitAuthor=The OpenList Projects Contributors github.com/OpenListTeam/OpenList/v4/internal/conf.GitCommit=$git_commit github.com/OpenListTeam/OpenList/v4/internal/conf.Version=$tag - github.com/OpenListTeam/OpenList/v4/internal/conf.WebVersion=rolling + github.com/OpenListTeam/OpenList/v4/internal/conf.WebVersion=latest + github.com/OpenListTeam/OpenList/v4/internal/conf.FrontendRepoDefault=${{ vars.FRONTEND_REPO || 'Ironboxplus/OpenList-Frontend' }} output: openlist$ext - name: Verify musl binary is static @@ -69,7 +73,7 @@ jobs: fi - name: Upload artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: openlist_${{ steps.short-sha.outputs.sha }}_${{ matrix.target }} path: build/* diff --git a/.github/workflows/issue_pr_comment.yml b/.github/workflows/issue_pr_comment.yml index 1b51e23b9..bc29e6bf6 100644 --- a/.github/workflows/issue_pr_comment.yml +++ b/.github/workflows/issue_pr_comment.yml @@ -16,7 +16,7 @@ jobs: if: github.event_name == 'issues' steps: - name: Check issue for unchecked tasks and reply - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | let comment = ""; @@ -81,7 +81,7 @@ jobs: if: github.event_name == 'pull_request' steps: - name: Check PR title for required prefix and comment - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | const title = context.payload.pull_request.title || ""; diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2fa13af18..76152ef24 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -44,7 +44,7 @@ jobs: swap-storage: true - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: '1.25.0' diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml index 80bdf9e3c..7503a3eda 100644 --- a/.github/workflows/release_docker.yml +++ b/.github/workflows/release_docker.yml @@ -45,13 +45,13 @@ jobs: - name: Checkout uses: actions/checkout@v6 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: go-version: '1.25.0' - name: Cache Musl id: cache-musl - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: build/musl-libs key: docker-musl-libs-v2 @@ -69,7 +69,7 @@ jobs: FRONTEND_REPO: ${{ vars.FRONTEND_REPO }} - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ${{ env.ARTIFACT_NAME }} overwrite: true @@ -85,13 +85,13 @@ jobs: - name: Checkout uses: actions/checkout@v6 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: go-version: '1.25.0' - name: Cache Musl id: cache-musl - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: build/musl-libs key: docker-musl-libs-v2 @@ -109,7 +109,7 @@ jobs: FRONTEND_REPO: ${{ vars.FRONTEND_REPO }} - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ${{ env.ARTIFACT_NAME_LITE }} overwrite: true @@ -147,7 +147,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@v6 - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v7 with: name: ${{ env.ARTIFACT_NAME }} path: 'build/' @@ -231,7 +231,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@v6 - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v7 with: name: ${{ env.ARTIFACT_NAME_LITE }} path: 'build/' diff --git a/.github/workflows/test_docker.yml b/.github/workflows/test_docker.yml index 16c299401..27a81137e 100644 --- a/.github/workflows/test_docker.yml +++ b/.github/workflows/test_docker.yml @@ -1,150 +1,195 @@ name: Beta Release (Docker) - on: workflow_dispatch: + inputs: + frontend_repo: + description: 'Frontend repo, e.g. Ironboxplus/OpenList-Frontend' + required: false + default: 'Ironboxplus/OpenList-Frontend' + type: string + frontend_channel: + description: 'Frontend release channel to build' + required: false + default: rolling + type: choice + options: + - rolling + - latest + - both push: branches: - - main - pull_request: - branches: - - main + - feat/dynamic-frontend + concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true env: - DOCKERHUB_ORG_NAME: ${{ vars.DOCKERHUB_ORG_NAME || 'openlistteam' }} - GHCR_ORG_NAME: ${{ vars.GHCR_ORG_NAME || 'openlistteam' }} - IMAGE_NAME: openlist-git - IMAGE_NAME_DOCKERHUB: openlist + FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true + GHCR_ORG_NAME: ${{ vars.GHCR_ORG_NAME || 'ironboxplus' }} # 👈 最好改成你的用户名,防止推错地方 + FRONTEND_REPO: ${{ github.event.inputs.frontend_repo || vars.FRONTEND_REPO || 'Ironboxplus/OpenList-Frontend' }} REGISTRY: ghcr.io - ARTIFACT_NAME: 'binaries_docker_release' - RELEASE_PLATFORMS: 'linux/amd64,linux/arm64,linux/arm/v7,linux/386,linux/arm/v6,linux/ppc64le,linux/riscv64,linux/loong64' ### Temporarily disable Docker builds for linux/s390x architectures for unknown reasons. - IMAGE_PUSH: ${{ github.event_name == 'push' }} - IMAGE_TAGS_BETA: | - type=ref,event=pr - type=raw,value=beta,enable={{is_default_branch}} + ARTIFACT_NAME_PREFIX: 'binaries_docker_release' + # 👇 关键修改:只保留 linux/amd64,删掉后面一长串 + RELEASE_PLATFORMS: 'linux/amd64' + # 👇 关键修改:强制允许推送,不用管是不是 push 事件 + IMAGE_PUSH: 'true' + # 👇 使用默认的前端仓库 (Ironboxplus/OpenList-Frontend) + # FRONTEND_REPO: 'Ironboxplus/OpenList-Frontend' jobs: build_binary: - name: Build Binaries for Docker Release (Beta) + name: Build Binaries (x64, front-${{ matrix.frontend_channel }}) runs-on: ubuntu-latest + strategy: + matrix: + frontend_channel: ${{ github.event_name == 'workflow_dispatch' && (github.event.inputs.frontend_channel == 'both' && fromJSON('["rolling","latest"]') || fromJSON(format('["{0}"]', github.event.inputs.frontend_channel))) || fromJSON('["rolling"]') }} steps: - name: Checkout uses: actions/checkout@v6 - - uses: actions/setup-go@v5 + - name: Setup Go + uses: actions/setup-go@v6 with: go-version: '1.25.0' + cache: true + cache-dependency-path: go.sum + + - name: Get Frontend Cache Version + id: frontend-cache + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + WEB_VERSION: ${{ matrix.frontend_channel }} + run: | + frontend_repo="${{ env.FRONTEND_REPO }}" + web_version="$WEB_VERSION" + github_auth_args=() + + if [ -n "$GH_TOKEN" ]; then + github_auth_args=(-H "Authorization: Bearer $GH_TOKEN") + fi + + if [ "$web_version" = "latest" ]; then + frontend_version=$(curl -fsSL "${github_auth_args[@]}" "https://api.github.com/repos/$frontend_repo/releases/latest" | jq -r '.tag_name') + else + # For rolling/dev channels, resolve the actual commit to bust stale caches + release_json=$(curl -fsSL "${github_auth_args[@]}" "https://api.github.com/repos/$frontend_repo/releases/tags/$web_version" 2>/dev/null || echo '{}') + asset_name=$(echo "$release_json" | jq -r '.assets[0].name // empty') + frontend_version="${web_version}-${asset_name:-unknown}" + fi + + echo "repo=$frontend_repo" >> "$GITHUB_OUTPUT" + echo "version=$frontend_version" >> "$GITHUB_OUTPUT" + echo "Frontend repo: $frontend_repo" + echo "Frontend cache version: $frontend_version" + + - name: Cache Frontend + id: cache-frontend + uses: actions/cache@v5 + with: + path: public/dist + key: frontend-${{ steps.frontend-cache.outputs.repo }}-${{ steps.frontend-cache.outputs.version }} + restore-keys: | + frontend-${{ steps.frontend-cache.outputs.repo }}- - name: Cache Musl id: cache-musl - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: build/musl-libs key: docker-musl-libs-v2 - name: Download Musl Library if: steps.cache-musl.outputs.cache-hit != 'true' - run: bash build.sh prepare docker-multiplatform env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: bash build.sh prepare docker-multiplatform - - name: Build go binary (beta) + - name: Build go binary run: bash build.sh beta docker-multiplatform env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - FRONTEND_REPO: ${{ vars.FRONTEND_REPO }} + WEB_VERSION: ${{ matrix.frontend_channel }} + FRONTEND_REPO: ${{ env.FRONTEND_REPO }} + SKIP_FRONTEND_FETCH: ${{ steps.cache-frontend.outputs.cache-hit == 'true' && 'true' || 'false' }} - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: - name: ${{ env.ARTIFACT_NAME }} + name: ${{ env.ARTIFACT_NAME_PREFIX }}-${{ matrix.frontend_channel }} overwrite: true - path: | - build/ - !build/*.tgz - !build/musl-libs/** + path: build/linux/amd64/openlist release_docker: needs: build_binary - name: Release Docker image (Beta) + name: Release Docker (x64, front-${{ matrix.frontend_channel }}) runs-on: ubuntu-latest permissions: packages: write strategy: matrix: + frontend_channel: ${{ github.event_name == 'workflow_dispatch' && (github.event.inputs.frontend_channel == 'both' && fromJSON('["rolling","latest"]') || fromJSON(format('["{0}"]', github.event.inputs.frontend_channel))) || fromJSON('["rolling"]') }} + # 四种变体,各自独立 image 名(推送到独立 GHCR repo) image: ["latest", "ffmpeg", "aria2", "aio"] include: - image: "latest" base_image_tag: "base" build_arg: "" - tag_favor: "" + image_name: "openlist" - image: "ffmpeg" base_image_tag: "ffmpeg" build_arg: INSTALL_FFMPEG=true - tag_favor: "suffix=-ffmpeg,onlatest=true" + image_name: "openlist-ffmpeg" - image: "aria2" base_image_tag: "aria2" build_arg: INSTALL_ARIA2=true - tag_favor: "suffix=-aria2,onlatest=true" + image_name: "openlist-aria2" - image: "aio" base_image_tag: "aio" build_arg: | INSTALL_FFMPEG=true INSTALL_ARIA2=true - tag_favor: "suffix=-aio,onlatest=true" + image_name: "openlist-aio" steps: - name: Checkout uses: actions/checkout@v6 - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v7 with: - name: ${{ env.ARTIFACT_NAME }} - path: 'build/' - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - + name: ${{ env.ARTIFACT_NAME_PREFIX }}-${{ matrix.frontend_channel }} + path: 'build/linux/amd64' - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + # 👇 只保留 GitHub 登录,删除了 DockerHub 登录 - name: Login to GitHub Container Registry - if: env.IMAGE_PUSH == 'true' uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Login to DockerHub Container Registry - if: env.IMAGE_PUSH == 'true' - uses: docker/login-action@v3 - with: - username: ${{ vars.DOCKERHUB_ORG_NAME_BACKUP || env.DOCKERHUB_ORG_NAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Docker meta id: meta uses: docker/metadata-action@v5 with: images: | - ${{ env.REGISTRY }}/${{ env.GHCR_ORG_NAME }}/${{ env.IMAGE_NAME }} - ${{ env.DOCKERHUB_ORG_NAME }}/${{ env.IMAGE_NAME_DOCKERHUB }} - tags: ${{ env.IMAGE_TAGS_BETA }} - flavor: | - ${{ matrix.tag_favor }} + ${{ env.REGISTRY }}/${{ env.GHCR_ORG_NAME }}/${{ matrix.image_name }} + tags: | + type=raw,value=front-${{ matrix.frontend_channel }} + type=raw,value=latest,enable=${{ matrix.frontend_channel == 'latest' }} - name: Build and push - id: docker_build uses: docker/build-push-action@v6 with: context: . file: Dockerfile.ci - push: ${{ env.IMAGE_PUSH == 'true' }} + push: true build-args: | BASE_IMAGE_TAG=${{ matrix.base_image_tag }} ${{ matrix.build_arg }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} platforms: ${{ env.RELEASE_PLATFORMS }} + cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.GHCR_ORG_NAME }}/${{ matrix.image_name }}:buildcache-front-${{ matrix.frontend_channel }} + cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.GHCR_ORG_NAME }}/${{ matrix.image_name }}:buildcache-front-${{ matrix.frontend_channel }},mode=max diff --git a/.gitignore b/.gitignore index 1d71f0d60..d42155110 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ .DS_Store output/ /dist/ - +.omx # Binaries for programs and plugins *.exe *.exe~ @@ -31,4 +31,5 @@ output/ /public/dist/* /!public/dist/README.md -.VSCodeCounter \ No newline at end of file +.VSCodeCounter +nul diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..55cdf77a8 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,430 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Core Development Principles + +1. **最小代码改动原则** (Minimum code changes): Make the smallest change necessary to achieve the goal +2. **不缓存整个文件原则** (No full file caching for seekable streams): For SeekableStream, use RangeRead instead of caching entire file +3. **必要情况下可以多遍上传原则** (Multi-pass upload when necessary): If rapid upload fails, fall back to normal upload + +## Build and Development Commands + +```bash +# Development +go run main.go # Run backend server (default port 5244) +air # Hot reload during development (uses .air.toml) +./build.sh dev # Build development version with frontend +./build.sh release # Build release version + +# Testing +go test ./... # Run all tests +go test ./drivers/115_open/ -v # Run tests for a specific driver +go test ./drivers/115_open/ -run TestCheckUploadCallback -v # Run a single test +go build ./drivers/115_open/... # Quick compile check for a package + +# Docker +docker-compose up # Run with docker-compose +docker build -f Dockerfile . # Build docker image +``` + +**Build Script Details** (`build.sh`): +- Fetches frontend from `$FRONTEND_REPO` (default: `Ironboxplus/OpenList-Frontend`) releases and embeds into `public/dist/` +- Injects version info via ldflags: `-X "github.com/OpenListTeam/OpenList/v4/internal/conf.BuiltAt=$(date +'%F %T %z')"` +- Supports `dev`, `beta`, and release builds +- Downloads prebuilt frontend distribution automatically + +**Go Version**: Requires Go 1.24+ (CI uses 1.25.0) + +**Module Replacements** (`go.mod`): Some dependencies use `replace` directives pointing to forks (e.g., `115-sdk-go` → `Ironboxplus/115-sdk-go`). When modifying SDK behavior, check if there's a local fork to edit. + +## Architecture Overview + +### Driver System (Storage Abstraction) + +OpenList uses a **driver pattern** to support 70+ cloud storage providers. Each driver implements the core `Driver` interface. + +**Location**: `drivers/*/` + +**Core Interfaces** (`internal/driver/driver.go`): +- `Reader`: List directories, generate download links (REQUIRED) +- `Writer`: Upload, delete, move files (optional) +- `ArchiveDriver`: Extract archives (optional) +- `LinkCacheModeResolver`: Custom cache TTL strategies (optional) + +**Driver Registration Pattern**: +```go +// In drivers/your_driver/meta.go +var config = driver.Config{ + Name: "YourDriver", + LocalSort: false, + NoCache: false, + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &YourDriver{} + }) +} +``` + +**Adding a New Driver**: +1. Copy `drivers/template/` to `drivers/your_driver/` +2. Implement `List()` and `Link()` methods (required) +3. Define `Addition` struct with configuration fields using struct tags: + - `json:"field_name"` - JSON field name + - `type:"select"` - Input type (select, string, text, bool, number) + - `required:"true"` - Required field + - `options:"a,b,c"` - Dropdown options + - `default:"value"` - Default value +4. Register driver in `init()` function + +**Example Driver Structure**: +```go +type YourDriver struct { + model.Storage + Addition + client *YourClient +} + +func (d *YourDriver) Init(ctx context.Context) error { + // Initialize client, login, etc. +} + +func (d *YourDriver) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + // Return list of files/folders +} + +func (d *YourDriver) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + // Return download URL or RangeReader +} +``` + +### Request Flow + +``` +HTTP Request (Gin Router) + ↓ +Middleware (Auth, CORS, Logging) + ↓ +Handler (server/handles/) + ↓ +fs.List/Get/Link (mount path → storage path conversion) + ↓ +op.List/Get/Link (caching, driver lookup) + ↓ +Driver.List/Link (storage-specific API calls) + ↓ +Response (JSON / Proxy / Redirect) +``` + +### Internal Package Structure + +| Package | Purpose | +|---------|---------| +| `bootstrap/` | Initialization sequence: config, DB, storages, servers | +| `conf/` | Configuration management | +| `db/` | Database models (SQLite/MySQL/Postgres) | +| `driver/` | Driver interface definitions | +| `fs/` | Mount path abstraction (converts `/mount/path` to storage + path) | +| `op/` | Core operations with caching and driver management | +| `stream/` | Streaming, range readers, link refresh, rate limiting | +| `model/` | Data models (Obj, Link, Storage, User) | +| `cache/` | Multi-level caching (directories, links, users, settings) | +| `net/` | HTTP utilities, proxy config, download manager | + +### Link Generation and Caching + +**Link Types**: +1. **Direct URL** (`link.URL`): Simple redirect to storage provider +2. **RangeReader** (`link.RangeReader`): Custom streaming implementation +3. **Refreshable Link** (`link.Refresher`): Auto-refresh on expiration + +**Cache System** (`internal/op/cache.go`): +- **Directory Cache**: Stores file listings with configurable TTL +- **Link Cache**: Stores download URLs (30min default) +- **User Cache**: Authentication data (1hr default) +- **Custom Policies**: Pattern-based TTL via `pattern:ttl` format + +**Cache Key Pattern**: `{storageMountPath}/{relativePath}` + +**Invalidation**: Recursive tree deletion for directory operations + +### Range Reader and Streaming + +**Location**: `internal/stream/` + +**Purpose**: Handle partial content requests (HTTP 206), multi-threaded downloads, and link refresh during streaming. + +**Key Components**: + +1. **RangeReaderIF**: Core interface for range-based reading + ```go + type RangeReaderIF interface { + RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) + } + ``` + +2. **RefreshableRangeReader**: Wraps RangeReader with automatic link refresh + - Detects expired links via error strings or HTTP status codes (401, 403, 410, 500) + - Calls `link.Refresher(ctx)` to get new link + - Resumes download from current byte position + - Max 3 refresh attempts to prevent infinite loops + +3. **Multi-threaded Downloader** (`internal/net/downloader.go`): + - Splits file into parts based on `Concurrency` and `PartSize` + - Downloads parts in parallel + - Assembles final stream + +**Stream Types and Reader Management**: + +⚠️ **CRITICAL**: SeekableStream.Reader must NEVER be created early! + +- **FileStream**: One-time sequential stream (e.g., HTTP body) + - `Reader` is set at creation and consumed sequentially + - Cannot be rewound or re-read + +- **SeekableStream**: Reusable stream with RangeRead capability + - Has `rangeReader` for creating new readers on-demand + - `Reader` should ONLY be created when actually needed for sequential reading + - **DO NOT create Reader early** - use lazy initialization via `generateReader()` + +**Common Pitfall - Early Reader Creation**: +```go +// ❌ WRONG: Creating Reader early +if _, ok := rr.(*model.FileRangeReader); ok { + rc, _ := rr.RangeRead(ctx, http_range.Range{Length: -1}) + fs.Reader = rc // This will be consumed by intermediate operations! +} + +// ✅ CORRECT: Let generateReader() create it on-demand +// Reader will be created only when Read() is called +return &SeekableStream{FileStream: fs, rangeReader: rr}, nil +``` + +**Why This Matters**: +- Hash calculation uses `StreamHashFile()` which reads the file via RangeRead +- If Reader is created early, it may be at EOF when HTTP upload actually needs it +- Result: `http: ContentLength=X with Body length 0` error + +**Hash Calculation for Uploads**: +```go +// For SeekableStream: Use RangeRead to avoid consuming Reader +if _, ok := file.(*SeekableStream); ok { + hash, err = stream.StreamHashFile(file, utils.MD5, 40, &up) + // StreamHashFile uses RangeRead internally, Reader remains unused +} + +// For FileStream: Must cache first, then calculate hash +_, hash, err = stream.CacheFullAndHash(file, &up, utils.MD5) +``` + +**Link Refresh Pattern**: +```go +// In op.Link(), a refresher is automatically attached +link.Refresher = func(refreshCtx context.Context) (*model.Link, model.Obj, error) { + // Get fresh link from storage driver + file, err := GetUnwrap(refreshCtx, storage, path) + newLink, err := storage.Link(refreshCtx, file, args) + return newLink, file, nil +} + +// RefreshableRangeReader uses this during streaming +if IsLinkExpiredError(err) && r.link.Refresher != nil { + newLink, _, err := r.link.Refresher(ctx) + // Resume from current position +} +``` + +**Proxy Function** (`server/common/proxy.go`): + +Handles multiple scenarios: +1. Multi-threaded download (`link.Concurrency > 0`) +2. Direct RangeReader (`link.RangeReader != nil`) +3. Refreshable link (`link.Refresher != nil`) ← Wraps with RefreshableRangeReader +4. Transparent proxy (forwards to `link.URL`) + +### Frontend Dist Serving + +**Location**: `server/static/static.go`, `internal/frontend/` + +The frontend dist has two sources, with a strict priority: + +1. **Embedded dist** (`public/dist/` via `go:embed`): Baked into the binary at build time. Always used on startup. +2. **Dynamic dist** (fetched by watcher): The `frontend.Watcher` checks GitHub every 30 minutes for a newer rolling release. If found, it downloads to `data/frontend_dist/` and hot-swaps the serving FS via `ReloadStatic()`. + +**Key design rule**: `initStatic()` always starts with the embedded dist (or `dist_dir` if configured). The cached dynamic dist in the data volume is never read on startup — only the watcher can activate it after verifying a newer version exists on GitHub. This prevents stale cache from overriding a newer Docker image. + +**Configuration** (`config.json`): +- `dist_dir`: Override with a custom local directory (highest priority, skips embedded) +- `frontend_repo`: GitHub repo for the watcher to check (default: `Ironboxplus/OpenList-Frontend`) + +**Startup flow**: +``` +initStatic() → embedded dist (or dist_dir) + ↓ +StartWatcher(ReloadStatic) → background goroutine + ↓ (every 30min) +FetchFromRolling() → compare cache version vs GitHub rolling tag commit + ↓ (if newer) +downloadAndExtract() → atomic swap in data/frontend_dist/dist/ + ↓ +ReloadStatic() → swap staticFS to new dist, re-render index.html +``` + +### Startup Sequence + +**Location**: `internal/bootstrap/run.go` + +Order of initialization: +1. `InitConfig()` - Load config, environment variables +2. `Log()` - Initialize logging +3. `InitDB()` - Connect to database +4. `data.InitData()` - Initialize default data +5. `LoadStorages()` - Load and initialize all storage drivers +6. `InitTaskManager()` - Start background tasks +7. `Start()` - Start HTTP/HTTPS/WebDAV/FTP/SFTP servers + +## Common Patterns + +### Error Handling + +Use custom errors from `internal/errs/`: +- `errs.NotImplement` - Feature not implemented +- `errs.ObjectNotFound` - File/folder not found +- `errs.NotFolder` - Path is not a directory +- `errs.StorageNotInit` - Storage driver not initialized + +**Link Expiry Detection**: +```go +// Checks error string for keywords: "expired", "invalid signature", "token expired" +// Also checks HTTP status: 401, 403, 410, 500 +if stream.IsLinkExpiredError(err) { + // Refresh link +} +``` + +### Upload and OSS Callback Validation + +Drivers that upload via Aliyun OSS (e.g., `115`, `115_open`) use a callback mechanism: after OSS stores the file, it POSTs to the storage provider's callback URL. The provider returns a JSON response indicating whether the file was registered. + +**Critical**: Always capture and validate the callback response: +```go +var bodyBytes []byte +_, err = bucket.CompleteMultipartUpload(imur, parts, + oss.Callback(base64.StdEncoding.EncodeToString([]byte(callback))), + oss.CallbackVar(base64.StdEncoding.EncodeToString([]byte(callbackVar))), + oss.CallbackResult(&bodyBytes), // ← MUST capture this +) +// Check both OSS error AND callback response +if err != nil { return err } +// Parse bodyBytes to verify {"state": true} +``` + +Without `oss.CallbackResult`, the upload appears successful (OSS returns 200) but the file is never registered on the provider's side. The local cache shows the file temporarily, but it vanishes on refresh. + +**`Put` vs `PutResult`**: Drivers can implement either `driver.Put` (returns `error`) or `driver.PutResult` (returns `model.Obj, error`). When `Put` returns nil, `op.Put` creates a temporary object in the directory cache. When `PutResult` returns an actual object, that object is used in the cache instead. + +### HybridCache replaces the old `MaxBufferLimit` truncation + +Historically `CacheFullAndHash` → `CacheFullAndWriter` → `cache()` capped buffering at `MaxBufferLimit` (~48MB), so non-seekable streams larger than that produced a truncated SHA1 (only the first 48MB hashed) and providers like 115 rejected the upload. The 115_open driver carried a `utils.CreateTempFile` workaround. + +**Current (2026-05-17)**: upstream PR #2460 unified caching via `internal/mem.HybridCache` — three tiers (heap memory → `LinearMemory` → temp file fallback) with 16MB blocks (`MaxBlockLimit`). `CacheFullAndWriter` now writes the *entire* stream regardless of size, so the truncation bug is gone for **all drivers**, not just 115_open. The 115_open workaround was removed; the standard `stream.CacheFullAndHash` path is back. + +**Unaffected paths**: +- `SeekableStream` (copy tasks): uses `RangeRead` directly, never goes through `cache()` +- Form uploads: `c.FormFile()` already stores to a temp file, so `GetFile()` returns non-nil and `CacheFullAndWriter` reads the entire file + +### Pass 2 prefetch on `hybridSectionReader` + +For sequential multipart uploads (115_open, baidu_netdisk, aliyundrive_open, etc.), `hybridSectionReader.GetSectionReader` launches a background goroutine to pre-read the next chunk while the caller uploads the current one. The next call picks up the prefetched block; mismatched offsets fall back to synchronous read. Prefetch is clamped to remaining file size, drains on `DiscardSection`, and surfaces errors on the next `GetSectionReader`. Cleanup is registered via `file.Add` so the goroutine is drained before `HybridCache` is freed. + +Drivers that already use `errgroup.Lifecycle` (e.g. 123/upload.go) with `Before` (List/GetSectionReader) and `Do` (upload) get overlap from the lifecycle pattern itself; the section-reader prefetch helps drivers that loop sequentially without Before/Do split (e.g. 115_open, which requires `oss.Sequential()`). + +### Merge-task ObjectNotFound tolerance + +Merge-mode `FileTransferTask.RunWithNextTaskCallback` calls `op.List(dst)` to build the `existedObjs` skip set. A non-existent dst must be treated as "empty" rather than fatal (otherwise resuming an interrupted merge to a fresh dst fails immediately). The logic is extracted into `existingDstFilesFn` in `internal/fs/copy_move.go`: + +```go +dstObjs, err := listDst(ctx, dstPath) +if err != nil && !errors.Is(err, errs.ObjectNotFound) { + return nil, errors.WithMessagef(err, "failed list dst [%s] objs", dstPath) +} +// non-existent dst → empty map, merge proceeds and creates dst on demand +``` + +A previous BFS-style "precreate one level of subdirectories" optimization was removed (2026-05-17, commit `6b3ce577`); directory creation now happens on demand via `op.Put`'s internal `MakeDir(parent)` and `op.MakeDir`'s recursive parent walk. The top-level dst `MakeDir` at `copy_move.go:198` is kept for early failure detection. 12 unit tests on `existingDstFilesFn` (including raw + wrapped `ObjectNotFound`) lock in the contract. + +### Saving Driver State + +When updating tokens or credentials: +```go +d.AccessToken = newToken +op.MustSaveDriverStorage(d) // Persists to database +``` + +### Rate Limiting + +Use `rate.Limiter` for API rate limits: +```go +type YourDriver struct { + limiter *rate.Limiter +} + +func (d *YourDriver) Init(ctx context.Context) error { + d.limiter = rate.NewLimiter(rate.Every(time.Second), 1) // 1 req/sec +} + +func (d *YourDriver) List(...) { + d.limiter.Wait(ctx) + // Make API call +} +``` + +### Context Cancellation + +Always respect context cancellation in long operations: +```go +select { +case <-ctx.Done(): + return nil, ctx.Err() +default: + // Continue operation +} +``` + +## Important Conventions + +**Naming**: +- Drivers: lowercase with underscores (e.g., `baidu_netdisk`, `aliyundrive_open`) +- Packages: lowercase (e.g., `internal/op`) +- Interfaces: PascalCase with suffix (e.g., `Reader`, `Writer`) + +**Driver Configuration Fields**: +- Use `driver.RootPath` or `driver.RootID` for root folder +- Add `omitempty` to optional JSON fields +- Use descriptive help text in struct tags + +**Retries and Timeouts**: +- Use `github.com/avast/retry-go` for retry logic +- Set reasonable timeouts on HTTP clients (default 30s in `base.RestyClient`) +- For unstable APIs, implement exponential backoff + +**Logging**: +- Use `logrus` via `log` package +- Levels: `log.Debugf`, `log.Infof`, `log.Warnf`, `log.Errorf` +- Include driver name in logs: `log.Infof("[driver_name] message")` + +## Project Context + +OpenList is a community-driven fork of AList, focused on: +- Long-term governance and trust +- Support for 70+ cloud storage providers +- Web UI for file management +- Multi-protocol support (HTTP, WebDAV, FTP, SFTP, S3) +- Offline downloads (Aria2, Transmission) +- Full-text search +- Archive extraction + +**License**: AGPL-3.0 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index c8344eb6c..6d4a6ef27 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -106,7 +106,7 @@ Violating these terms may lead to a permanent ban. ### 4. Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community -standards, including sustained inappropriate behavior, harassment of an +standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within @@ -116,7 +116,7 @@ the community. This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at -https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. +. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). @@ -124,5 +124,5 @@ enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at -https://www.contributor-covenant.org/faq. Translations are available at -https://www.contributor-covenant.org/translations. +. Translations are available at +. diff --git a/COMPATIBILITY_REPORT.md b/COMPATIBILITY_REPORT.md new file mode 100644 index 000000000..0b8be3b52 --- /dev/null +++ b/COMPATIBILITY_REPORT.md @@ -0,0 +1,204 @@ +# Rebase兼容性分析报告 + +## 提交概览 +共引入 **21个commits**,主要涉及以下模块: + +### 核心功能改动 + +#### 1. **链接刷新机制** (`internal/stream/util.go`) +**Commits**: +- `4c33ffa4` feat(link): add link refresh capability for expired download links +- `f38fe180` fix(stream): 修复链接过期检测逻辑,避免将上下文取消视为链接过期 +- `7cf362c6` fix(stream): 更新过期链接检查逻辑,支持所有4xx客户端错误 +- `03fbaf1c` refactor(stream): 移除过时的链接刷新逻辑,添加自愈读取器以处理0字节读取 + +**核心代码**: +```go +// 新增常量 +MAX_LINK_REFRESH_COUNT = 50 // 链接最大刷新次数 +MAX_RANGE_READ_RETRY_COUNT = 5 // RangeRead重试次数(从3提升到5) + +// 新增函数 +IsLinkExpiredError(err error) bool // 判断是否为链接过期错误 + +// 新增结构 +RefreshableRangeReader struct { + link *model.Link + size int64 + innerReader model.RangeReaderIF + mu sync.Mutex + refreshCount int // 防止无限循环 +} + +selfHealingReadCloser struct { + // 检测0字节读取,自动刷新链接 +} +``` + +**功能说明**: +1. **链接过期检测**: 识别多种云盘的过期错误(expired, token expired, access denied, 4xx状态码等) +2. **自动刷新**: 检测到过期时自动调用Refresher获取新链接,最多刷新50次 +3. **自愈机制**: 处理某些云盘返回200但内容为空的情况(0字节读取检测) +4. **并发安全**: 使用sync.Mutex保护共享状态 +5. **Context隔离**: 刷新时使用WithoutCancel避免用户取消操作影响刷新 + +**潜在风险**: +- ✅ Context.WithoutCancel需要Go 1.21+ +- ✅ 并发场景下的锁竞争 +- ✅ refreshCount可能在某些场景下不递增导致无限循环 + +--- + +#### 2. **目录预创建优化** (`internal/fs/copy_move.go`) +**Commit**: `ce0da112` fix(copy_move): 将预创建子目录的深度从2级调整为1级 + +**核心代码**: +```go +func (t *FileTransferTask) preCreateDirectoryTree(objs []model.Obj, dstBasePath string, maxDepth int) error { + // 第一轮:创建直接子目录 + for _, obj := range objs { + if obj.IsDir() { + subdirPath := stdpath.Join(dstBasePath, obj.GetName()) + op.MakeDir(t.Ctx(), t.DstStorage, subdirPath) + subdirs = append(subdirs, obj) + } + } + + // 停止递归条件 + if maxDepth <= 0 { + return nil + } + + // 第二轮:递归创建嵌套目录 + for _, subdir := range subdirs { + subObjs := op.List(...) + preCreateDirectoryTree(subObjs, subdirDstPath, maxDepth-1) + } +} +``` + +**功能说明**: +1. **深度控制**: 默认maxDepth=1,只预创建2级目录(当前+子级) +2. **防止深度递归**: 避免在大型项目中递归过深导致栈溢出或性能问题 +3. **错误容忍**: MakeDir失败时继续处理其他目录 +4. **Context感知**: 每次循环检查ctx.Err()支持取消操作 + +**潜在风险**: +- ✅ op.MakeDir和op.List调用需要存储初始化 +- ✅ 大量目录时的性能问题 +- ✅ Context取消时的资源清理 + +--- + +#### 3. **网络优化** (`drivers/`, `internal/net/`) +**Commits**: +- `b9dafa65` feat(network): 增加对慢速网络的支持,调整超时和重试机制 +- `bce47884` fix(driver): 增加夸克分片大小调整逻辑,支持重试机制 +- `0b8471f6` feat(quark_open): 添加速率限制和重试逻辑 + +**功能说明**: +1. 提升RangeRead重试次数: 3 → 5 +2. 调整网络超时参数 +3. 添加分片上传重试逻辑 + +--- + +#### 4. **驱动修复** +**Commits**: +- `da2812c0` fix(google_drive): 更新Put方法以支持可重复读取流和不可重复读取流的MD5校验 +- `5a6bad90` feat(google_drive): 添加文件夹创建的锁机制和重试逻辑 +- `a54b2388` feat(google_drive): 添加处理重复文件名的功能 +- `9ef22ec9` fix(driver): fix file copy failure to 123pan due to incorrect etag +- `0ead87ef` fix(alias): update storage retrieval method in listRoot function +- `311f6246` fix: 修复500 panic和NaN问题 + +--- + +## 兼容性评估 + +### ✅ 编译兼容性 +- 构建成功,无语法错误 +- 依赖版本无冲突 + +### ✅ API兼容性 +- 新增函数不破坏现有接口 +- RefreshableRangeReader实现model.RangeReaderIF接口 +- 向后兼容旧代码 + +### ⚠️ 运行时兼容性 +**需要验证的场景**: +1. **并发安全**: RefreshableRangeReader的并发读取 +2. **资源泄漏**: Context取消时goroutine是否正确退出 +3. **边界条件**: + - refreshCount达到50次的行为 + - 0字节读取检测的准确性 + - maxDepth=0时的目录创建 +4. **错误处理**: + - nil Refresher时的处理 + - 链接刷新失败时的回退机制 +5. **性能**: + - 大文件下载时的刷新开销 + - 深层目录结构的预创建性能 + +--- + +## 测试需求 + +### 必须测试的场景 + +#### Stream包测试 +1. **IsLinkExpiredError准确性** + - 各种云盘的过期错误格式 + - Context取消不应判断为过期 + - HTTP 4xx/5xx的区分 + +2. **RefreshableRangeReader可靠性** + - 正常读取流程 + - 自动刷新触发和成功 + - 达到最大刷新次数 + - 并发读取安全性 + - Context取消的正确处理 + +3. **selfHealingReadCloser** + - 0字节读取检测 + - 刷新重试机制 + - 资源正确关闭 + +#### FS包测试 +1. **preCreateDirectoryTree** + - 深度控制正确性(0, 1, 2级) + - 大量目录的性能 + - Context取消的响应 + - 错误容忍性 + +--- + +## 风险等级: **中等** + +**原因**: +- ✅ 新功能设计合理,有明确的边界和错误处理 +- ⚠️ 并发场景需要充分测试 +- ⚠️ 链接刷新逻辑复杂,需要验证各种边界情况 +- ⚠️ 依赖op包的函数需要正确的初始化 + +--- + +## 推送建议: **通过测试后可推送** + +**前置条件**: +1. 完成全面的单元测试(见下方测试代码) +2. 验证并发安全性 +3. 确认Context取消不会导致资源泄漏 +4. 性能测试通过(大文件、深层目录) + +**建议测试命令**: +```bash +# 单元测试 +go test ./internal/stream ./internal/fs -v -count=1 -race + +# 压力测试 +go test ./internal/stream -run Stress -v -count=10 + +# 完整测试套件 +go test ./... -short -count=1 +``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6064b03cb..57317c629 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,7 +7,7 @@ Prerequisites: - [git](https://git-scm.com) -- [Go 1.24+](https://golang.org/doc/install) +- [Go](https://golang.org/doc/install) version declared in [`go.mod`](./go.mod) - [gcc](https://gcc.gnu.org/) - [nodejs](https://nodejs.org/) @@ -16,8 +16,8 @@ Prerequisites: Fork and clone `OpenList` and `OpenList-Frontend` anywhere: ```shell -$ git clone https://github.com//OpenList.git -$ git clone --recurse-submodules https://github.com//OpenList-Frontend.git +git clone https://github.com//OpenList.git +git clone --recurse-submodules https://github.com//OpenList-Frontend.git ``` ## Creating a branch @@ -25,7 +25,7 @@ $ git clone --recurse-submodules https://github.com//OpenList-Fro Create a new branch from the `main` branch, with an appropriate name. ```shell -$ git checkout -b +git checkout -b ``` ## Preview your change @@ -33,26 +33,36 @@ $ git checkout -b ### backend ```shell -$ go run main.go +go run main.go ``` ### frontend ```shell -$ pnpm dev +pnpm dev ``` ## Add a new driver Copy `drivers/template` folder and rename it, and follow the comments in it. +## Community and policies + +By contributing, you agree to follow the repository's code of conduct and license terms. + +- Code of conduct: [CODE_OF_CONDUCT.md](./CODE_OF_CONDUCT.md) +- License: [LICENSE](./LICENSE) +- Security issues: please report privately according to [SECURITY.md](./SECURITY.md) + +If your contribution includes substantial AI-assisted content, disclose the tools used and the scope of assistance in the pull request. + ## Create a commit Commit messages should be well formatted, and to make that "standardized". Submit your pull request. For PR titles, follow [Conventional Commits](https://www.conventionalcommits.org). -https://github.com/OpenListTeam/OpenList/issues/376 + It's suggested to sign your commits. See: [How to sign commits](https://docs.github.com/en/authentication/managing-commit-signature-verification/signing-commits) @@ -72,6 +82,19 @@ At least 1 approving review is required by reviewers with write access. You can (Optional) After your pull request is merged, you can delete your branch. +## AI Disclosure + +If your pull request includes substantial AI-assisted content, disclose it in the PR description. + +Please include: + +- Tools used, such as ChatGPT, GitHub Copilot, Claude, Cursor, or other AI tools. +- Usage scope, such as code generation, refactoring, documentation, tests, translation, or review assistance. +- Confirmation that you have reviewed and validated all AI-assisted content before submission. +- Confirmation that the submitted content complies with this repository's license and contribution policies. + +Minor AI assistance, such as typo fixes, autocomplete, formatting suggestions, or wording polish, does not need to be disclosed. + --- Thank you for your contribution! Let's make OpenList better together! diff --git a/JOURNAL.md b/JOURNAL.md new file mode 100644 index 000000000..40e75604e --- /dev/null +++ b/JOURNAL.md @@ -0,0 +1,486 @@ +# Development Journal + +Chronological log of all changes in this fork, from earliest to latest. + +--- + +## 2026-04-25 — Initial Feature Batch (rebased onto up/main) + +### `17be63fb` feat(stream): link refresh, self-healing reader, seekable prefetch, and upload hash rework +- Added `RefreshableRangeReader`: wraps `RangeReader` with automatic link refresh on expiry (max 50 attempts) +- Added `selfHealingReadCloser`: detects 0-byte reads, connection resets, and `io.ErrUnexpectedEOF`; reconnects from current offset transparently +- Added `IsLinkExpiredError()`: checks error strings + HTTP 4xx status codes for link expiry +- Added 2-window async prefetch in `directSectionReader`: while uploading chunk N, prefetch chunk N+1 via goroutine +- Reworked `StreamHashFile` to use `RangeRead` for `SeekableStream` (never consumes `Reader`) +- Added `ReadFullWithRangeRead` with retry (max 5 attempts, 1-5s backoff) +- Files: `internal/stream/util.go`, `internal/stream/stream.go` + +### `1db9136f` feat(google_drive): duplicate filename handling, folder lock, retry, and MD5 checksum +- Added `mkdirLocks` (`sync.Map`) to prevent concurrent creation of duplicate folders +- Added existence check with retry before folder creation +- Added 500ms consistency delay after folder creation +- Added MD5 hash computation for upload integrity (via `StreamHashFile`) +- Added `chunkUpload` with retry and `RangeRead`-based streaming +- Files: `drivers/google_drive/driver.go`, `drivers/google_drive/util.go` + +### `f9bc1567` feat(115_open): permanent delete, proxy_range, offline task fixes, and error handling +- Added `RemoveWay` config option: "trash" (default) or "delete" (permanent) +- Implemented `removePermanently`: deletes from recycle bin after trash +- Added `findRecycleBinEntry` with paginated recycle bin search +- Added `matchRecycleBinEntry` with multi-strategy matching (ID → SHA1 → name+size) +- Added `findRecycleBinEntryWithRetry` (4 attempts, 300ms backoff) for eventual consistency +- Added `FlexString` CID handling (numeric/string JSON interop) +- Added `proxy_range` option exposure +- Files: `drivers/115_open/driver.go`, `drivers/115_open/meta.go`, `drivers/115_open/upload.go`, `drivers/115_open/driver_test.go` + +### `f1493fc9` feat(offline_download): multi-page task retrieval and task limit wait mechanism +- 115 `OfflineList` now paginates through all pages (was only page 1) +- Added task limit wait mechanism in offline download client +- Moved 115 offline task cleanup from `Run()` to `Update()` so it runs even if transfer fails +- Files: `internal/offline_download/115_open/client.go`, `internal/offline_download/tool/download.go` + +### `35130094` feat(drivers): baidu streaming upload, quark rate-limit/retry, 123pan etag fix +- **Baidu Netdisk**: extracted upload logic to `upload.go`, full streaming upload support +- **Quark Open**: added rate limiter, retry with chunk size adjustment on 413 error +- **123 Open**: fixed copy failure due to incorrect etag, added SHA1 rapid upload +- Normalized `StreamHashFile` progress weight to 100 across all drivers +- Files: `drivers/baidu_netdisk/upload.go`, `drivers/quark_open/driver.go`, `drivers/123_open/driver.go` + +### `7787ddbc` fix(core): copy_move depth, alias storage retrieval, sftp symlink, 500 panic +- Fixed `preCreateDirectoryTree` depth from 2 to 1 to avoid deep recursion +- Fixed alias storage retrieval method in `listRoot` +- Fixed SFTP symlink handling +- Fixed 500 panic and NaN issues +- Files: `internal/fs/copy_move.go`, `drivers/alias/util.go`, `drivers/sftp/types.go` + +### `4e735df0` feat(frontend): dynamic frontend fetching, CI upgrades, and build infrastructure +- Added `internal/frontend/fetcher.go`: auto-download frontend dist from GitHub releases +- Added `internal/frontend/watcher.go`: periodic check (30min) for new versions +- Added `FrontendRepoDefault` ldflags variable for build-time frontend repo injection +- CI: action version upgrades, frontend caching, version matrix builds +- `build.sh`: frontend repo configurable via `FRONTEND_REPO` env var +- Files: `internal/frontend/`, `.github/workflows/`, `build.sh`, `internal/conf/`, `server/static/` + +### `0bf12c5c` chore: add project docs and update dependencies (115-sdk-go fork) +- Added `CLAUDE.md` with comprehensive project guidance +- Added `COMPATIBILITY_REPORT.md` for 115-sdk-go fork analysis +- Updated 115-sdk-go dependency to `v0.2.5` (FlexString CID support) +- Files: `CLAUDE.md`, `COMPATIBILITY_REPORT.md`, `go.mod`, `go.sum` + +--- + +## 2026-05-08 — Rebase onto upstream + Code Review Fixes + +### Rebase onto `up/main` @ `c7c0cfae` +- `feat/dynamic-frontend` cleanly rebased (8 commits, 0 conflicts) +- Upstream additions included: path validation, SplitSeq perf, ObjectAlreadyExists check, qBittorrent login fix, custom share IDs, Getter interfaces (webdav/s3/115_open), about page logo fix + +### `0369bc0a` fix: bounded auth retry, mkdirLocks cleanup, EOF handling, tar size limit, dist swap lock, Go 1.26 vet +Code review identified 16 issues (2 CRITICAL, 5 HIGH). Fixed 6: +- **Google Drive Put 401 recursion** (CRITICAL): replaced infinite recursive `d.Put()` call with `putWithRetry()` + `maxPutAuthRetries=2` +- **mkdirLocks memory leak** (HIGH): added `defer mkdirLocks.Delete(lockKey)` after unlock +- **selfHealingReadCloser EOF** (HIGH): removed `io.EOF` from reconnect trigger, kept only `io.ErrUnexpectedEOF` +- **Frontend tar extraction** (HIGH): added `maxExtractFileSize=500MB` + `io.LimitReader` + `hdr.Size` check +- **Frontend dist swap TOCTOU** (HIGH): added `distSwapMu` mutex around rename window +- **RefreshableRangeReader concurrency** (CRITICAL): documented that local `reader` copy is safe after `innerReader` replacement +- **Go 1.26 vet**: fixed `fmt.Errorf` non-constant format strings across 8 drivers/packages +- Added tests: `drivers/google_drive/driver_test.go`, stream EOF tests, frontend oversized file test +- Files: 14 files changed, +222/-22 + +--- + +## 2026-05-09 — OSS Upload Fix + Hash Prefetch + +### `214881b3` fix(115_open): handle OSS upload timeout and PartAlreadyExist retry +- **Root cause**: `ResponseHeaderTimeout=15s` in shared `NewHttpClient()` was too short for uploading 20MB OSS parts. Timeout caused part to be uploaded but unconfirmed; retry hit `PartAlreadyExist` (409). +- **Fix 1**: Created `NewOSSUploadHttpClient()` with `ResponseHeaderTimeout=5min` dedicated to OSS uploads +- **Fix 2**: In `multpartUpload` retry, detect `PartAlreadyExist` → call `ListUploadedParts` to recover the part's ETag → treat as success +- Added `isPartAlreadyExistError()` helper +- Tests: `upload_test.go` (PartAlreadyExist detection), `oss_test.go` (upload client timeout) +- Files: `drivers/115_open/upload.go`, `internal/net/oss.go` + +### `94591821` perf(stream): add double-buffer prefetch to StreamHashFile for SeekableStream +- **Before**: hash calculation read 10MB chunks sequentially (network idle during hash computation) +- **After**: while hashing chunk N, goroutine prefetches chunk N+1 via `ReadFullWithRangeRead` +- Extracted `streamHashSeekableWithPrefetch()` with double-buffering pattern +- Hash values identical — no change to upload flow or rapid-upload logic +- `FileStream` path unchanged (sequential read) +- Test: `TestStreamHashFile_SeekablePrefetchProducesSameHash` +- Files: `internal/stream/util.go`, `internal/stream/util_test.go` + +--- + +## 2026-05-10 — 115 Rate Limiting + Concurrent Token Refresh Fix + +### 问题现象 + +复制任务(`/scnet/` → `/storage/my_115/`)全部失败,错误 `code: 0, message:` 和 `code: 40100000, message: 参数错误!`。目录确实存在,单任务正常,多 worker 并发时全部报错。 + +### 根因 1:Put 方法多个 SDK 调用未走限流器 + +`d1b72178` fix(115_open): rate-limit every SDK call in Put method + +- **发现**:`Put()` 入口只调一次 `WaitLimit`,后续 3-4 个 SDK 请求(`UploadInit` ×3 + `UploadGetToken`)直接发出,不经过限流器 +- **影响**:10 个 copy worker 并发时,不受限的 Put 请求和其他走限流器的 List/MakeDir 请求一起打到 115 API,瞬时 QPS 超过 115 的限制,API 返回 `state:false, code:0`(空错误)拒绝所有请求 +- **修复**:移除 Put 入口的单次 `WaitLimit`,在每个 SDK 调用(`UploadInit`、`UploadGetToken`)前单独加 `WaitLimit` +- **测试**:`TestPutRateLimitsEverySDKCall` — 设置 10 req/s 限流器,验证 3 个 UploadInit 调用之间有 >=70ms 间隔;`TestPutRateLimitsPreHashPath` — 验证秒传成功路径 +- Files: `drivers/115_open/driver.go`, `drivers/115_open/driver_test.go` + +### 根因 2:SDK RefreshToken 无并发保护 + +`823f46ba` fix(deps): bump 115-sdk-go to v0.2.6 for concurrent refresh fix + +- **发现**:日志显示 `40140117 refresh frequently` 和 `40140120 refresh token error` 从 3 月 25 日就开始出现。115 的 refresh token 是一次性的——token 过期后多个 goroutine 同时调 `authRequest`,同时检测到 401,同时调 `RefreshToken`。第一个成功消耗了旧 RT,后续的全部失败(RT 已作废),token 被损坏或清空 +- **时间线**(`my_115` 实例): + - 07:06:19 — 存储加载成功 + - 07:28:34 — copy workers 从 3 改成 10 + - 09:11:29 — 最后一条成功的 `[115] GetFiles` 日志 + - 09:11-09:45 — 35 分钟无 `[115]` 日志(全是文件上传,走 UploadInit 不产生 `[115]` 日志) + - 09:46:01 — 首次 `40100000 参数错误`(token 已失效/清空) + - 从 07:06 到 09:46 正好 ~2h40m,115 access token TTL 约 2h +- **修复**(SDK `v0.2.6`):`authRequest` 中加 `refreshMu sync.Mutex` + double-check pattern。发请求前锁内读取 `usedToken`,遇 401 后锁内比对 `c.accessToken == usedToken`,若已被别的 goroutine 刷新过则跳过,未刷新才调 `RefreshToken` +- **测试**:`TestConcurrentAuthRequestRefreshesOnlyOnce` — 10 个并发 goroutine 同时打过期 token,断言 `RefreshToken` 只被调用 1 次,全部 goroutine 成功。`count=3` 稳定通过 +- Files: SDK `client.go`, `request.go`, `request_test.go`; OP `go.mod`, `go.sum` + +--- + +## 2026-05-11 — 115 GetFolderInfoByPath 空数据处理 + +### 问题现象 + +复制任务预建目标子目录时报错:`failed to get obj: json: cannot unmarshal array into Go value of type sdk.GetFolderInfoResp`。只有**不存在**的子目录报错,已存在的目录正常。 + +### 根因 + +115 Open API 的 `GetFolderInfoByPath` 在路径不存在时返回 `{state:true, data:[]}` 而不是正常的错误码。SDK 的 `authRequest` 直接把 `[]` 反序列化到 `GetFolderInfoResp`(struct)→ `json.UnmarshalTypeError`。该错误不是 `errs.ObjectNotFound`,导致 `MakeDir` 在 `op/fs.go:350` 当作未知错误抛出,而不是正常进入"目录不存在→创建"流程。 + +### 修复 + +**SDK v0.2.8**(`cf4f508` fix: return ErrDataEmpty when API responds with empty data): +- 新增 `ErrDataEmpty` sentinel error +- `authRequest` 在 `extractData` 模式下:`data` 为 `null`/空 → 直接返回 `ErrDataEmpty`;`data` 为 `[]` 反序列化到 struct 失败 → 也返回 `ErrDataEmpty`(反序列化到 slice 类型则正常通过,不影响 `DelFile` 等返回空数组的 API) + +**Driver 层**: +- `Open115.Get()` 捕获 `sdk.ErrDataEmpty` → 转为 `errs.ObjectNotFound` +- `MakeDir` 的 `errs.IsObjectNotFound` 检测通过 → 正常创建目录 + +### 测试 + +- SDK: `TestAuthRequestReturnsErrDataEmptyForEmptyArray`、`TestAuthRequestReturnsErrDataEmptyForNull`、`TestAuthRequestSucceedsForValidObject` +- Driver: `TestGetReturnsObjectNotFoundForEmptyData`、`TestGetReturnsObjForExistingFolder` +- 既有 27 个测试全部通过,含 `TestOpen115RemoveDeleteReturnsErrorWhenRecycleEntryMissing`(验证 `DelFile` 的 `data:[]` 不受影响) + +### 教训:不要 force-push tag + +Go module proxy 会缓存 tag 第一次发布时的内容。Force-push 更新 tag 后,`go.sum` 中记录的旧 hash 与新内容不匹配,触发 `checksum mismatch` 安全错误。正确做法是打新版号(v0.2.7 → v0.2.8)。 + +- Files: SDK `error.go`, `request.go`, `request_test.go`; OP `drivers/115_open/driver.go`, `drivers/115_open/driver_test.go`, `go.mod`, `go.sum` + +--- + +## 已完成调查:FileStream.cache truncated stream(2026-04-07) + +### 现象 + +- `failed to read all data: (expect =50331648, actual =41592644) unexpected EOF` +- 调用链:`FsStream -> fs.PutDirectly -> op.Put -> FileStream.cache` +- 115_open 与 google_drive 均出现 + +### 根因 + +上游请求体提前结束(truncated stream),`FileStream.cache()` 的 `io.ReadFull` 严格检测并报错。`50331648 = 48MiB`(MaxBufferLimit 窗口),`41592644 ≈ 39.66MiB`(实际收到的字节数)。actual 值不固定,排除驱动逻辑在固定偏移崩溃的可能。 + +`3b2f9d55` 将"超限时落盘"改为"裁剪到 MaxBufferLimit",使错误暴露更早,但非根因。 + +### 处置 + +P0:在上传入口增加"声明长度 vs 实际接收长度"日志;排查反向代理超时。 + +--- + +## 2026-05-13 — 115 Open 上传静默失败修复(双 bug) + +### 问题现象 + +前端上传文件显示成功(进度 100%,PUT `/api/fs/put` 返回 HTTP 200),但刷新页面后文件消失。日志无任何错误。 + +### 日志分析 + +从 330K 行日志中定位到两次上传(12:53:31 和 13:10:05),均耗时 ~2 分钟(实际数据传输),返回 200。上传前后 `arc` 目录文件数始终为 11,文件从未出现在 115 的文件列表中。 + +对比历史上传: + +- **2026-04-06**:3 次上传(10s/22s/25s)全部成功,`PUB` 目录从 2 增至 5 文件 +- **2026-04-14**:第 1 次上传失败(5→5 文件),第 2 次成功(5→6 文件) +- **2026-05-13**:两次上传均失败(11→11 文件) + +### Bug 1:OSS 回调未校验(静默失败) + +对比非 Open 版 115 驱动(`drivers/115/util.go`)与 Open 版(`drivers/115_open/upload.go`): + +**非 Open 版**(正确): + +```go +var bodyBytes []byte +bucket.CompleteMultipartUpload(imur, parts, + oss.Callback(...), + oss.CallbackResult(&bodyBytes), // ← 捕获回调响应 +) +var uploadResult UploadResult +json.Unmarshal(bodyBytes, &uploadResult) +return &uploadResult, uploadResult.Err(...) // ← 校验 state 字段 +``` + +**Open 版**(有 bug): + +```go +// callbackRespBytes := make([]byte, 1024) ← 注释掉了! +_, err = bucket.CompleteMultipartUpload(imur, parts, + oss.Callback(...), + // oss.CallbackResult(&callbackRespBytes), ← 注释掉了! +) +if err != nil { return err } // ← 只检查 OSS 层错误 +return nil // ← 115 回调失败被忽略 +``` + +OSS 上传流程:客户端上传分片到 OSS → `CompleteMultipartUpload` → OSS 调用 115 的回调 URL → 115 返回 `{"state": true/false}`。OSS 只要回调 URL 返回 HTTP 200 就认为成功(`err == nil`),但 115 可能在 body 中返回 `{"state": false}` 表示文件注册失败。Open 版驱动没有捕获回调 body,所以 115 拒绝注册文件时完全静默。 + +`op.Put`(`internal/op/fs.go:714-731`)在 `err == nil` 时把文件临时加入目录缓存 → 前端显示成功 → 用户刷新后缓存过期,文件消失。 + +**修复**:新增 `UploadCallbackResult` + `checkUploadCallback()`,`singleUpload` 和 `multpartUpload` 均添加 `oss.CallbackResult(&bodyBytes)` 捕获并校验回调。 + +### Bug 2:Stream 模式大文件 SHA1 截断(根因) + +加了回调校验后看到真正的错误:`code=10002, message=校验文件失败,请重新上传。` + +关键线索:Form 模式(`PUT /api/fs/form`)和 Copy task 均成功,只有 Stream 模式(`PUT /api/fs/put`)失败。 + +**根因**:`CacheFullAndHash` → `CacheFullAndWriter` → `cache()` 将缓存上限截断为 `MaxBufferLimit`(~48MB): + +```go +func (f *FileStream) cache(maxCacheSize int64) (model.File, error) { + if maxCacheSize > int64(conf.MaxBufferLimit) { + maxCacheSize = int64(conf.MaxBufferLimit) // ← 截断! + } + // ...只读取前 48MB 到 peekBuff +} +``` + +870MB 文件:SHA1 只算了前 48MB → `UploadInit` 带错误的 SHA1 → 完整 870MB 上传到 OSS → 115 校验:SHA1(870MB) ≠ SHA1(前 48MB) → 拒绝。 + +**为什么 Form 模式不受影响**:`c.FormFile()` 已将文件存入 `*os.File` 临时文件,`GetFile()` 返回非 nil,`CacheFullAndWriter` 走第一个分支直接读整个文件算 hash。 + +**为什么 Copy task 不受影响**:走 `SeekableStream` 路径(`isSeekable == true`),用 `StreamHashFile` → `RangeRead` 按需读取,不经过 `cache()`。 + +**修复**:`drivers/115_open/driver.go` Put 方法,对非 seekable 的 `FileStream`,当 `GetFile() == nil` 时先用 `utils.CreateTempFile` 将完整流写入临时文件,设置 `fs.Reader = tmpF`,再用 `StreamHashFile` 从临时文件计算正确的 SHA1。 + +### 上传数据流(修复后) + +``` +Stream 上传 (FileStream, !isSeekable): + HTTP Body → CreateTempFile(磁盘) → StreamHashFile(临时文件) → UploadInit → 分片上传(临时文件) → 清理 + +Copy task (SeekableStream, isSeekable): [不变] + Pass 1: RangeRead(源存储) → StreamHashFile → SHA1 + Pass 2: RangeRead(源存储) → directSectionReader → 分片上传 +``` + +### 附带修复:前端下载阻塞启动 + +`server/static/static.go`:移除 `initStatic` 中 30 秒超时的同步前端下载(国内 GitHub 不通导致启动卡住),改为直接使用内嵌 dist,由 `Watcher` 后台异步下载 rolling 版本并热替换。 + +### 测试 + +新增 4 个回调校验测试,既有 30 个测试全部通过: + +- `TestCheckUploadCallbackSuccess` — state=true 正常通过 +- `TestCheckUploadCallbackStateFalse` — state=false 返回含 code/message 的错误 +- `TestCheckUploadCallbackEmptyBody` — 空响应返回错误 +- `TestCheckUploadCallbackInvalidJSON` — 非法 JSON 返回错误 + +### `cache()` 截断是否影响其他驱动 + +`cache()` 的 `MaxBufferLimit` 截断是刻意设计(避免流式缓存占用过多内存),但 `CacheFullAndWriter` 名为"CacheFull"却内部调 `cache()` 导致大文件只缓存部分。任何依赖 `CacheFullAndHash` 对大文件(> MaxBufferLimit)计算 hash 的非 seekable 流驱动都可能受影响。本次修复仅在 115_open 驱动层绕过,未改动核心 stream 包。 + +- Files: `drivers/115_open/driver.go`, `drivers/115_open/driver_test.go`, `drivers/115_open/upload.go`, `server/static/static.go` + +--- + +## 2026-05-17 — Rebase onto upstream HybridCache + Pass 2 prefetch + BFS precreate 删除 + +### 背景 + +Upstream(OpenListTeam/OpenList)通过 PR #2460(`b6db83ed`)引入 `HybridCache`:三级缓存(普通堆内存 → `LinearMemory` → 文件落盘),用来统一 stream 缓存路径。这恰好替代了 2026-05-13 在 115_open 驱动层手写的 SHA1 截断绕过(`utils.CreateTempFile` workaround)。 + +### 行动 1:Rebase 16 个提交到 upstream HybridCache + +用 `git rebase --onto upstream/main` 把本地分支挪到 HybridCache 之上。冲突解决原则:**保留双方功能**。 + +关键冲突文件: + +- **`internal/stream/stream.go`** — 取 upstream HybridCache 版(`cache()` 内部三级分配),删掉本地原来的 `MaxBufferLimit` workaround +- **`internal/stream/util.go`** — 保留本地全部新增:`RefreshableRangeReader`、`selfHealingReadCloser`、`IsLinkExpiredError`、`ReadFullWithRangeRead`、`streamHashSeekableWithPrefetch`;丢弃本地的 `directSectionReader` + prefetch(在新结构上重写);upstream 的 `byteSectionReader` / `hybridSectionReader` 保留 +- **`drivers/115_open/driver.go`** Put 方法 — 删掉 `CreateTempFile` workaround,恢复标准 `CacheFullAndHash` 路径(HybridCache 已经保证完整缓存) +- **`drivers/baidu_netdisk/upload.go`** — 接口重命名:`StreamSectionReaderIF` → `StreamSectionReader`(upstream 名称) +- **`server/static/static.go`** — 保留本地动态前端 `reloadableFS`(rebase 期间一度被覆盖,已修正) + +Rebase 后 22 commits ahead of origin,force-pushed 至 `7995dbdb`。 + +### HybridCache vs 旧 workaround + +| 维度 | 旧 workaround | HybridCache | +|---|---|---| +| 修复范围 | 仅 115_open | stream 层,全 driver 受益 | +| 大文件存储 | 一律落盘 | 优先内存,紧张才落盘 | +| 块策略 | 单 buffer | 16MB 分块(`MaxBlockLimit`) | +| GC 友好度 | 一个大 `[]byte` | `LinearMemory` 避开 Go heap | +| Section reader 适配 | 整 buffer 切片 | 每块独立,对 `hybridSectionReader` 天然友好 | + +### 行动 2:Pass 2 prefetch on `hybridSectionReader`(commit `7995dbdb`) + +旧 `directSectionReader` 已被 upstream 的 `hybridSectionReader` 替代,但失去了"上传当前分片时异步预读下一分片"的优化(影响 115_open 这种顺序上传的 driver)。 + +TDD 红→绿→重构: + +- 在 `hybridSectionReader` 上加 `prefetchTask` 状态 + `schedulePrefetch` / `takePrefetched` / `waitPrefetch` +- 每次 `GetSectionReader` 返回后启动 goroutine 预读下一块到一个 pending block +- 下次 `GetSectionReader` 命中即拿、不命中走原路径 +- `DiscardSection` 排空 prefetch;prefetch 错误延迟到下次 `GetSectionReader` 暴露 +- `file.Add(closerFunc(waitPrefetch))` 注册清理,避免 `hc` 被释放时 prefetch 还在写 + +8 个测试覆盖:overlap 时序、顺序正确性、最后分片小于 partSize、Discard 配合、prefetch 错误传播、不超 EOF、长度 clamp、非 hybrid 路径回退。 + +- Files: `internal/stream/util.go`, `internal/stream/section_reader_prefetch_test.go` + +### 行动 3:删除 BFS precreate(commit `6b3ce577`) + +旧 `preCreateDirTreeFn` 提前 BFS 创建 dst 一层子目录。原始动机:避开"merge 任务 List 不存在的 dst 会致命"的 bug。 + +排查后发现该 bug 已由三层独立修复覆盖: + +1. **`copy_move.go:213` 容错判断**(PR #1898,你最初提交、KirCute 优化、Tron 合入):`if err != nil && !errors.Is(err, errs.ObjectNotFound)` 容忍 dst 不存在 +2. **`op.Put` 内置 MakeDir**(`internal/op/fs.go:682`):上传前自动建父目录 +3. **`op.MakeDir` 递归 + 缓存同步重试**(`internal/op/fs.go:354-364`):递归建父链 + 100ms 重试解决云盘缓存延迟 + +BFS precreate 在功能上已变成纯性能优化(提前批建 + 缓存预热),不再是 bug 修复的关键依赖。 + +**TDD 流程**: + +1. 抽出 `existingDstFilesFn(ctx, listDst, dstPath) → map[string]bool`,把 merge 模式构建 `existedObjs` 的逻辑独立为可注入纯函数 +2. 写 12 个单测:empty/files-only/dirs-only/mixed/raw-ObjectNotFound/wrapped-ObjectNotFound(#1898 回归保护)/其他 List 错误/ctx 取消前/ctx 取消中/重名/万级 dst 性能/单次 List 合理性 +3. 全绿后删除 `preCreateDirTreeFn`、`preCreateDirectoryTree` 方法、调用点、原有 11 个 `TestPreCreateDirTreeFn_*` 测试 +4. 保留顶层 dst MakeDir(`copy_move.go:198`)做早期失败检测 + +净 -95 行代码。`existingDstFilesFn` 的 ObjectNotFound 容错测试(`TestExistingDstFilesFn_DstDoesNotExist_RawError` / `_WrappedError`)是 #1898 的回归护栏,确保未来重构不把那个老 bug 漏回去。 + +- Files: `internal/fs/copy_move.go`, `internal/fs/copy_move_test.go` + +### 行动后状态 + +- `go build ./...` 干净 +- `internal/fs`、`internal/stream`、`internal/op`、`drivers/115_open`、`drivers/baidu_netdisk` 全部测试通过 +- 不相关失败:`internal/net/oss_test.go` 的 HTTPS proxy 测试(rebase 前就失败)、`pkg/aria2/rpc` 测试(需要本地 aria2 服务) +- Branch `feat/dynamic-frontend` HEAD 推至 `6b3ce577` + +--- + +## 2026-05-19 — Rebase onto `hybrid_cache` 包抽取 + prefetch 双 commit squash + +### 背景 + +上游 PR #2477(commit `9eae6258`)做了一次破坏性重构: + +- 把 `internal/mem`(`HybridCache` / `LinearMemory`)抽到独立的 `internal/hybrid_cache` 包(别名 `hcache`) +- 引入 `BackingStore` 抽象:`BufferStore`(内存)+ `FileStore`(磁盘),由 `HybridCache` 内部自动切换,外部代码不再分支 +- 配置字段重命名:`conf.CacheThreshold` → `conf.AutoMemoryLimit`(json `cache_threshold` → `auto_memory_limit`、env 同步) +- 删除 `pkg/buffer/bytes.go` / `pkg/buffer/file.go`,`pkg/buffer/buffer.go` + `type.go` 重塑 +- `internal/stream/util.go` 从 ~960 行精简至 298 行——上游只保留 `GetRangeReaderFromLink` / `CacheFullAndHash` / `NewStreamSectionReader` / `hybridSectionReader` 等核心 helpers,所有本地扩展需要重新落上去 + +同批次另有两个修复(`daad21ef` 139 driver `Connection` 头、`31b41f99` qBittorrent 5.2 204 登录修复),均不与本地冲突。 + +### 行动 1:Rebase 20 个本地 commit 到 `op/main` @ `31b41f99` + +`git tag pre-rebase-2026-05-19 HEAD` 后 `git rebase op/main`,命中两处冲突: + +**冲突点 1 — `internal/stream/stream.go`(`fccab338`)** + +上游把 `cache(maxCacheSize int64)` 重命名为 `ensureCache(size int64)`,函数体已转为新 `hcache.HybridCache`。处理:直接取上游版,本地 fccab338 patch 的语义已被上游覆盖。 + +**冲突点 2 — `internal/stream/util.go`(`fccab338`)** + +本地有独立的 `byteSectionReader` 分支(小文件走单独缓冲池)。上游新设计下 `BackingStore` 自动按文件大小选择 `BufferStore` / `FileStore`,`NewStreamSectionReader` 简化为: + +```go +if file.GetFile() != nil { return &cachedSectionReader{...} } +// 否则全部走 +return &hybridSectionReader{...} +``` + +处理:删除 `byteSectionReader` / `bytesRefReadSeeker` 类型,以及 `TestHybridSectionReader_NonPrefetchPath` 测试(断言"小文件不走 hybrid"已失效)。 + +**冲突点 3 — `internal/stream/util.go`(`7995dbdb` Pass 2 prefetch)** + +`hybridSectionReader` 结构体内的 `hc` 字段:本地 `*mem.HybridCache` vs 上游 `*hcache.HybridCache`。处理:保留本地新增的 `fileSize int64` 字段(Pass 2 prefetch 用于 EOF clamp),类型切到 `*hcache.HybridCache`。 + +**冲突点 4 — `internal/stream/section_reader_prefetch_test.go`** + +`conf.CacheThreshold` 三处引用全部改成 `conf.AutoMemoryLimit`。 + +其余 16 个 commit 干净回放(包括 `163e5c81` 的 util.go selfHealingReadCloser EOF 调整、`07410ad3` 的 Pass 1 hash prefetch、`4e91d93a` 的 `NewOSSUploadHttpClient` 等),上游 patch 已覆盖大部分 `mem.` → `hcache.` 翻译工作。 + +### 行动 2:Squash 两个 prefetch commit + +`588b7024 perf(stream): add double-buffer prefetch to StreamHashFile` 和 `a32544fa perf(stream): pass 2 prefetch on hybridSectionReader` 同属"上传 pipeline 重叠优化"主题但相隔 5 个 commit。用 `git rebase -i op/main` 加自定义 `GIT_SEQUENCE_EDITOR` / `GIT_EDITOR`(Python 脚本临时丢在 `.git/`)做 reorder + squash,得到 `01364250 perf(stream): two-pass prefetch on upload pipeline`,叙事完整。 + +之间的 4 个非 stream commit(`8a2fb995` / `9389b219` / `610c3cf9` / `65c713a0`)reorder 时无冲突——它们碰的是 driver 和 `server/static`,不沾 util.go。 + +未 squash 的 `902c8b82 feat(stream)` 是综合提交(selfHealing + link refresh + hash 工具 + 当时被丢弃的旧 prefetch),独立保留以维持叙事。 + +### 测试结果 + +| 包 | 结果 | +|---|---| +| `internal/stream` | ok(7 个 `TestHybridSectionReader_*` prefetch、2 个 `TestRefreshableRangeReader_*`、2 个 `TestSelfHealingReadCloser_*`、`TestStreamHashFile_SeekablePrefetchProducesSameHash` 全部通过) | +| `internal/fs` | ok | +| `internal/op` | ok | +| `drivers/115_open` | ok(含 `TestNewOSSUploadHttpClientHasLongerTimeout`) | +| `internal/hybrid_cache` | ok | +| `internal/frontend` | ok | +| `internal/net` | 唯一失败 `TestNewOSSClientUsesEnvironmentHTTPSProxy`(白名单,rebase 前就失败) | + +### 行动后状态 + +- `feat/dynamic-frontend` HEAD `39c74672`,19 commits ahead of `op/main`(原 20,squash -1) +- `git push --force-with-lease origin feat/dynamic-frontend` 成功(`7e140903...39c74672`) +- 本地 tag `pre-rebase-2026-05-19` 保留作为回滚点 +- `env.md` 更新 "Current upstream base" 到 `31b41f99` + +### 设计要点:byteSectionReader 消失为何不损失功能 + +旧 `byteSectionReader` 的优化逻辑:小文件用 `pool.Pool[[]byte]` 复用 buffer,避免每个 section 都分配新切片。 + +上游 `BackingStore` 的等价机制:`BufferStore` 内部就是块池化(按 `blockSize` 分块,块从 `LinearMemory` 复用),且自动按 `AutoMemoryLimit`/文件大小决策。也就是说"小文件走纯内存"这个语义保留了,只是不再暴露为独立类型。 + +代价:失去了"按文件大小选择 reader 实现"的显式分支,调试时需要看 `BackingStore` 内部状态。换取:调用方代码统一、Pass 2 prefetch 可以盲目挂在 `hybridSectionReader` 上不用考虑 byte 分支。 + +--- + +## Architecture Notes + +### Upload Data Flow (Current) +``` +Pass 1: Hash Calculation (with prefetch) + StreamHashFile → ReadFullWithRangeRead + chunk N: hash computation (CPU) + chunk N+1: async prefetch (network I/O) ← overlapped + +Pass 2: Multipart Upload (with prefetch) + hybridSectionReader.GetSectionReader + chunk N: upload to cloud (network I/O) + chunk N+1: async prefetch (network I/O) ← overlapped +``` + +### Proxy Architecture +- `conf.Conf.ProxyAddress` → global HTTP proxy for all server-side traffic +- Per-storage `WebProxy` / `DownProxyURL` → browser download only, NOT copy tasks +- OSS uploads use dedicated `NewOSSUploadHttpClient()` with longer timeouts diff --git a/OVERVIEW.md b/OVERVIEW.md new file mode 100644 index 000000000..01c71fb63 --- /dev/null +++ b/OVERVIEW.md @@ -0,0 +1,79 @@ +# OpenList Fork — Project Overview + +This is a fork of [OpenListTeam/OpenList](https://github.com/OpenListTeam/OpenList) (upstream), maintained at [Ironboxplus/OpenList](https://github.com/Ironboxplus/OpenList). + +## Branch Structure + +| Branch | Purpose | +|--------|---------| +| `feat/dynamic-frontend` | **Active development branch** (default). Rebased on `up/main`. | +| `copy` | Legacy branch with unsquashed commits. Superseded by `feat/dynamic-frontend`. | +| `main` | Synced from upstream via rebase workflow. | + +## Documentation Index + +| File | Description | +|------|-------------| +| [CLAUDE.md](CLAUDE.md) | AI coding guidance: architecture, driver system, stream/upload internals, conventions | +| [COMPATIBILITY_REPORT.md](COMPATIBILITY_REPORT.md) | 115-sdk-go fork compatibility analysis | +| [OVERVIEW.md](OVERVIEW.md) | This file — project index and high-level map | +| [JOURNAL.md](JOURNAL.md) | Chronological development log with all changes | +| [CONTRIBUTING.md](CONTRIBUTING.md) | Upstream contribution guidelines | +| [README.md](README.md) | Upstream project README | +| [SECURITY.md](SECURITY.md) | Security policy | + +## Key Subsystems Modified (vs Upstream) + +### 1. Full-Streaming Upload with Async Prefetch +- **Files**: `internal/stream/util.go`, `internal/stream/stream.go` +- Two-pass flow: hash calculation (with double-buffer prefetch) → multipart upload (with 2-window async prefetch) +- `SeekableStream` uses `RangeRead` exclusively, never consumes the `Reader` +- `selfHealingReadCloser`: transparent link refresh + reconnect-from-offset on stream interruption +- `RefreshableRangeReader`: auto-refresh expired download links (up to 50 retries) + +### 2. Dynamic Frontend Fetcher +- **Files**: `internal/frontend/fetcher.go`, `internal/frontend/watcher.go` +- Auto-downloads frontend dist from GitHub releases on startup (when `WebVersion` is rolling/beta/dev) +- `Watcher` polls every 30 min for new versions, hot-swaps dist directory +- Configurable `FrontendRepo` (default: `OpenListTeam/OpenList-Frontend`, overridable per deployment) +- `FrontendRepoDefault` injectable via ldflags at build time + +### 3. Driver Enhancements + +| Driver | Changes | +|--------|---------| +| 115 Open | Permanent delete with recycle-bin retry, FlexString CID, OSS upload timeout fix, PartAlreadyExist recovery, proxy_range, offline task multi-page | +| Google Drive | Duplicate filename handling, per-folder MakeDir lock, bounded 401 retry, MD5 checksum, mkdirLocks cleanup | +| Baidu Netdisk | Full streaming upload (extracted from monolithic driver) | +| Quark Open | Rate limiting, retry with chunk size adjustment | +| 123 Open | SHA1 rapid upload, etag fix, StreamHashFile progress normalization | +| Aliyundrive Open | StreamHashFile progress normalization | + +### 4. CI/CD & Build +- **Files**: `.github/workflows/`, `build.sh` +- Frontend version matrix (rolling + latest) for Docker builds +- `FrontendRepoDefault` x-flag in CI +- Action version upgrades (checkout v6, setup-go v6, cache v5) +- Frontend caching in CI to avoid redundant downloads +- Static linking verification for musl builds + +### 5. Offline Download +- **Files**: `internal/offline_download/115_open/`, `internal/offline_download/tool/` +- Multi-page task retrieval for 115 +- Task limit wait mechanism +- Cleanup moved to `Update()` to ensure it runs even if transfer fails + +## Upstream Sync + +Remote `up` points to `https://github.com/OpenListTeam/OpenList.git`. + +```bash +git fetch up +git rebase up/main # on feat/dynamic-frontend +``` + +Current base: `up/main` @ `c7c0cfae` (2026-05-06) + +## Global Proxy + +All server-side HTTP traffic (driver API calls, copy/move transfers, frontend fetching) uses `conf.Conf.ProxyAddress` (global setting). Per-storage `WebProxy`/`DownProxyURL` only affects browser-facing download behavior, NOT server-side copy tasks. diff --git a/README.md b/README.md index 1ec96df4d..4a94dbfcf 100644 --- a/README.md +++ b/README.md @@ -116,9 +116,10 @@ Thank you for your support and understanding of the OpenList project. ## Document -- 📘 [Global Site](https://doc.oplist.org) -- 📚 [Backup Site](https://doc.openlist.team) -- 🌏 [CN Site](https://doc.oplist.org.cn) +- 📘 [Docs](https://doc.oplist.org) +- 🌏 [CN Mirror](https://doc.oplist.org.cn) +- ⚖️ [Terms of Use](https://doc.oplist.org/terms) +- 🔒 [Privacy Policy](https://doc.oplist.org/privacy) ## Demo diff --git a/README_cn.md b/README_cn.md index 1bcfdbed9..55fe22106 100644 --- a/README_cn.md +++ b/README_cn.md @@ -116,14 +116,15 @@ OpenList 是一个由 OpenList 团队独立维护的开源项目,遵循 AGPL-3 ## 文档 -- 🌏 [国内站点](https://doc.oplist.org.cn) -- 📘 [海外站点](https://doc.oplist.org) -- 📚 [备用站点](https://doc.openlist.team) +- 📘 [文档](https://doc.oplist.org) +- 🌏 [中国镜像](https://doc.oplist.org.cn) +- ⚖️ [使用条款](https://doc.oplist.org/terms) +- 🔒 [隐私政策](https://doc.oplist.org/privacy) -## 演示 +## Demo -- 🇨🇳 [国内演示站](https://demo.oplist.org.cn) -- 🌎 [海外演示站](https://demo.oplist.org) +- 🌎 [全球 Demo](https://demo.oplist.org) +- 🇨🇳 [中国 Demo](https://demo.oplist.org.cn) ## 讨论 diff --git a/README_ja.md b/README_ja.md index 3a5d5d19f..261223de3 100644 --- a/README_ja.md +++ b/README_ja.md @@ -116,14 +116,15 @@ OpenListプロジェクトへのご支援とご理解をありがとうござい ## ドキュメント -- 📘 [グローバルサイト](https://doc.oplist.org) -- 📚 [バックアップサイト](https://doc.openlist.team) -- 🌏 [CNサイト](https://doc.oplist.org.cn) +- 📘 [ドキュメント](https://doc.oplist.org) +- 🌏 [中国ミラー](https://doc.oplist.org.cn) +- ⚖️ [利用規約](https://doc.oplist.org/terms) +- 🔒 [プライバシーポリシー](https://doc.oplist.org/privacy) -## デモ +## Demo -- 🌎 [グローバルデモ](https://demo.oplist.org) -- 🇨🇳 [CNデモ](https://demo.oplist.org.cn) +- 🌎 [グローバル Demo](https://demo.oplist.org) +- 🇨🇳 [中国 Demo](https://demo.oplist.org.cn) ## ディスカッション diff --git a/README_nl.md b/README_nl.md index 86e90e740..d3be2703f 100644 --- a/README_nl.md +++ b/README_nl.md @@ -116,9 +116,10 @@ Dank u voor uw ondersteuning en begrip ## Documentatie -- 📘 [Global Site](https://doc.oplist.org) -- 📚 [Backup Site](https://doc.openlist.team) -- 🌏 [CN Site](https://doc.oplist.org.cn) +- 📘 [Documentatie](https://doc.oplist.org) +- 🌏 [CN Mirror](https://doc.oplist.org.cn) +- ⚖️ [Gebruiksvoorwaarden](https://doc.oplist.org/terms) +- 🔒 [Privacybeleid](https://doc.oplist.org/privacy) ## Demo diff --git a/build.sh b/build.sh index b6baca1fe..4e7b4f608 100644 --- a/build.sh +++ b/build.sh @@ -5,7 +5,7 @@ gitAuthor="The OpenList Projects Contributors " gitCommit=$(git log --pretty=format:"%h" -1) # Set frontend repository, default to OpenListTeam/OpenList-Frontend -frontendRepo="${FRONTEND_REPO:-OpenListTeam/OpenList-Frontend}" +frontendRepo="${FRONTEND_REPO:-Ironboxplus/OpenList-Frontend}" githubAuthArgs="" if [ -n "$GITHUB_TOKEN" ]; then @@ -18,6 +18,8 @@ if [[ "$*" == *"lite"* ]]; then useLite=true fi +skipFrontendFetch="${SKIP_FRONTEND_FETCH:-false}" + if [ "$1" = "dev" ]; then version="dev" webVersion="rolling" @@ -31,6 +33,10 @@ else webVersion=$(eval "curl -fsSL --max-time 2 $githubAuthArgs \"https://api.github.com/repos/$frontendRepo/releases/latest\"" | grep "tag_name" | head -n 1 | awk -F ":" '{print $2}' | sed 's/\"//g;s/,//g;s/ //g') fi +if [ -n "$WEB_VERSION" ]; then + webVersion="$WEB_VERSION" +fi + echo "backend version: $version" echo "frontend version: $webVersion" if [ "$useLite" = true ]; then @@ -46,6 +52,7 @@ ldflags="\ -X 'github.com/OpenListTeam/OpenList/v4/internal/conf.GitCommit=$gitCommit' \ -X 'github.com/OpenListTeam/OpenList/v4/internal/conf.Version=$version' \ -X 'github.com/OpenListTeam/OpenList/v4/internal/conf.WebVersion=$webVersion' \ +-X 'github.com/OpenListTeam/OpenList/v4/internal/conf.FrontendRepoDefault=$frontendRepo' \ " # Keep sqlite driver tag selection centralized to avoid target drift. @@ -97,6 +104,11 @@ AssertStaticBinary() { } FetchWebRolling() { + if [ "$skipFrontendFetch" = "true" ] && [ -n "$(find public/dist -mindepth 1 -print -quit 2>/dev/null)" ]; then + echo "using cached frontend dist from public/dist" + return 0 + fi + pre_release_json=$(eval "curl -fsSL --max-time 2 $githubAuthArgs -H \"Accept: application/vnd.github.v3+json\" \"https://api.github.com/repos/$frontendRepo/releases/tags/rolling\"") pre_release_assets=$(echo "$pre_release_json" | jq -r '.assets[].browser_download_url') @@ -110,6 +122,11 @@ FetchWebRolling() { } FetchWebRelease() { + if [ "$skipFrontendFetch" = "true" ] && [ -n "$(find public/dist -mindepth 1 -print -quit 2>/dev/null)" ]; then + echo "using cached frontend dist from public/dist" + return 0 + fi + release_json=$(eval "curl -fsSL --max-time 2 $githubAuthArgs -H \"Accept: application/vnd.github.v3+json\" \"https://api.github.com/repos/$frontendRepo/releases/latest\"") release_assets=$(echo "$release_json" | jq -r '.assets[].browser_download_url') @@ -236,8 +253,8 @@ BuildDockerMultiplatform() { docker_lflags="$(GetMuslStaticLdflags)" export CGO_ENABLED=1 - OS_ARCHES=(linux-amd64 linux-arm64 linux-386 linux-riscv64 linux-ppc64le linux-loong64) ## Disable linux-s390x builds - CGO_ARGS=(x86_64-linux-musl-gcc aarch64-linux-musl-gcc i486-linux-musl-gcc riscv64-linux-musl-gcc powerpc64le-linux-musl-gcc loongarch64-linux-musl-gcc) ## Disable s390x-linux-musl-gcc builds + OS_ARCHES=(linux-amd64) ## Disable linux-s390x builds + CGO_ARGS=(x86_64-linux-musl-gcc) ## Disable s390x-linux-musl-gcc builds for i in "${!OS_ARCHES[@]}"; do os_arch=${OS_ARCHES[$i]} cgo_cc=${CGO_ARGS[$i]} @@ -257,15 +274,17 @@ BuildDockerMultiplatform() { GO_ARM=(6 7) export GOOS=linux export GOARCH=arm - for i in "${!DOCKER_ARM_ARCHES[@]}"; do - docker_arch=${DOCKER_ARM_ARCHES[$i]} - cgo_cc=${CGO_ARGS[$i]} - export GOARM=${GO_ARM[$i]} - export CC=${cgo_cc} - echo "building for $docker_arch" - CGO_LDFLAGS="-static" go build -o build/${docker_arch%%-*}/${docker_arch##*-}/"$appName" -ldflags="$docker_lflags" -tags=jsoniter . - AssertStaticBinary "build/${docker_arch%%-*}/${docker_arch##*-}/$appName" - done + # ARM docker variants stay disabled on this branch to keep the workflow x64-only. + # If they are re-enabled later, they should follow the same static-link pattern. + # for i in "${!DOCKER_ARM_ARCHES[@]}"; do + # docker_arch=${DOCKER_ARM_ARCHES[$i]} + # cgo_cc=${CGO_ARGS[$i]} + # export GOARM=${GO_ARM[$i]} + # export CC=${cgo_cc} + # echo "building for $docker_arch" + # CGO_LDFLAGS="-static" go build -o build/${docker_arch%%-*}/${docker_arch##*-}/"$appName" -ldflags="$docker_lflags" -tags=jsoniter . + # AssertStaticBinary "build/${docker_arch%%-*}/${docker_arch##*-}/$appName" + # done } BuildRelease() { @@ -655,7 +674,11 @@ if [ "$buildType" = "dev" ]; then fi elif [ "$buildType" = "release" -o "$buildType" = "beta" ]; then if [ "$buildType" = "beta" ]; then - FetchWebRolling + if [ "$WEB_VERSION" = "latest" ]; then + FetchWebRelease + else + FetchWebRolling + fi else FetchWebRelease fi diff --git a/drivers/115/driver.go b/drivers/115/driver.go index d4f5741d0..7c7c581a8 100644 --- a/drivers/115/driver.go +++ b/drivers/115/driver.go @@ -5,6 +5,7 @@ import ( "strings" "sync" + "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/model" streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" @@ -68,6 +69,9 @@ func (d *Pan115) Link(ctx context.Context, file model.Obj, args model.LinkArgs) return nil, err } userAgent := args.Header.Get("User-Agent") + if userAgent == "" { + userAgent = base.UserAgent + } downloadInfo, err := d.client.DownloadWithUA(file.(*FileObj).PickCode, userAgent) if err != nil { return nil, err diff --git a/drivers/115_open/driver.go b/drivers/115_open/driver.go index 278ae0f7f..b86f7b38b 100644 --- a/drivers/115_open/driver.go +++ b/drivers/115_open/driver.go @@ -2,6 +2,8 @@ package _115_open import ( "context" + "encoding/json" + "errors" "fmt" "net/http" stdpath "path" @@ -14,11 +16,13 @@ import ( "github.com/OpenListTeam/OpenList/v4/cmd/flags" "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" + log "github.com/sirupsen/logrus" "golang.org/x/time/rate" ) @@ -30,6 +34,12 @@ type Open115 struct { parentPath string } +var ( + // 回收站列表存在短暂最终一致性延迟,永久删除 fallback 查找增加短重试。 + recycleBinLookupMaxAttempts = 4 + recycleBinLookupRetryDelay = 300 * time.Millisecond +) + func (d *Open115) Config() driver.Config { return config } @@ -99,13 +109,20 @@ func (d *Open115) Drop(ctx context.Context) error { } func (d *Open115) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + start := time.Now() + log.Infof("[115] List request started for dir: %s (ID: %s)", dir.GetName(), dir.GetID()) + var res []model.Obj pageSize := int64(d.PageSize) offset := int64(0) + pageCount := 0 + for { if err := d.WaitLimit(ctx); err != nil { return nil, err } + + pageStart := time.Now() resp, err := d.client.GetFiles(ctx, &sdk.GetFilesReq{ CID: dir.GetID(), Limit: pageSize, @@ -115,7 +132,12 @@ func (d *Open115) List(ctx context.Context, dir model.Obj, args model.ListArgs) // Cur: 1, ShowDir: true, }) + pageDuration := time.Since(pageStart) + pageCount++ + log.Infof("[115] GetFiles page %d took: %v (offset=%d, limit=%d)", pageCount, pageDuration, offset, pageSize) + if err != nil { + log.Errorf("[115] GetFiles page %d failed after %v: %v", pageCount, pageDuration, err) return nil, err } res = append(res, utils.MustSliceConvert(resp.Data, func(src sdk.GetFilesResp_File) model.Obj { @@ -127,10 +149,17 @@ func (d *Open115) List(ctx context.Context, dir model.Obj, args model.ListArgs) } offset += pageSize } + + totalDuration := time.Since(start) + log.Infof("[115] List request completed in %v (%d pages, %d files)", totalDuration, pageCount, len(res)) + return res, nil } func (d *Open115) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + start := time.Now() + log.Infof("[115] Link request started for file: %s", file.GetName()) + if err := d.WaitLimit(ctx); err != nil { return nil, err } @@ -146,20 +175,39 @@ func (d *Open115) Link(ctx context.Context, file model.Obj, args model.LinkArgs) return nil, fmt.Errorf("can't convert obj") } pc := obj.Pc + + apiStart := time.Now() + log.Infof("[115] Calling DownURL API...") resp, err := d.client.DownURL(ctx, pc, ua) + apiDuration := time.Since(apiStart) + log.Infof("[115] DownURL API took: %v", apiDuration) + if err != nil { + log.Errorf("[115] DownURL API failed after %v: %v", apiDuration, err) return nil, err } u, ok := resp[obj.GetID()] if !ok { return nil, fmt.Errorf("can't get link") } - return &model.Link{ + + totalDuration := time.Since(start) + log.Infof("[115] Link request completed in %v (API: %v)", totalDuration, apiDuration) + + link := &model.Link{ URL: u.URL.URL, Header: http.Header{ "User-Agent": []string{ua}, }, - }, nil + } + // Tie the cache TTL to the CDN's own `t=` expiry so OP never serves a + // URL that 115's CDN has already invalidated. Without this, OP would + // hand out a dead URL and 115 responds with 200 + Content-Length: 0, + // which downstream clients see as a corrupt/empty stream. + if ttl, ok := parseCDNExpiry(u.URL.URL); ok { + link.Expiration = &ttl + } + return link, nil } func (d *Open115) Get(ctx context.Context, path string) (model.Obj, error) { @@ -169,15 +217,14 @@ func (d *Open115) Get(ctx context.Context, path string) (model.Obj, error) { path = stdpath.Join(d.parentPath, path) resp, err := d.client.GetFolderInfoByPath(ctx, path) if err != nil { + if errors.Is(err, sdk.ErrDataEmpty) { + return nil, errs.ObjectNotFound + } return nil, err } - return &Obj{ - Fid: resp.FileID, - Fn: resp.FileName, - Fc: resp.FileCategory, - Sha1: resp.Sha1, - Pc: resp.PickCode, - }, nil + log.Debugf("[115] GetFolderInfoByPath(%s) => Size=%q FileCategory=%q FileID=%s", + path, resp.Size, resp.FileCategory, resp.FileID) + return fromFolderInfo(resp) } func (d *Open115) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { @@ -218,7 +265,7 @@ func (d *Open115) Rename(ctx context.Context, srcObj model.Obj, newName string) return nil, err } _, err := d.client.UpdateFile(ctx, &sdk.UpdateFileReq{ - FileID: srcObj.GetID(), + FileID: srcObj.GetID(), FileName: newName, }) if err != nil { @@ -254,42 +301,265 @@ func (d *Open115) Remove(ctx context.Context, obj model.Obj) error { if !ok { return fmt.Errorf("can't convert obj") } - _, err := d.client.DelFile(ctx, &sdk.DelFileReq{ + resp, err := d.client.DelFile(ctx, &sdk.DelFileReq{ FileIDs: _obj.GetID(), ParentID: _obj.Pid, }) if err != nil { return err } - return nil + if d.RemoveWay != "delete" { + return nil + } + return d.removePermanently(ctx, _obj, resp) } -func (d *Open115) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { - err := d.WaitLimit(ctx) +func (d *Open115) removePermanently(ctx context.Context, obj *Obj, deleteResp []string) error { + var directDeleteErr error + for _, tid := range deleteResp { + tid = strings.TrimSpace(tid) + if tid == "" { + continue + } + if err := d.deleteRecycleBinEntry(ctx, tid); err == nil { + return nil + } else if directDeleteErr == nil { + directDeleteErr = err + } + } + + recycleEntry, err := d.findRecycleBinEntryWithRetry(ctx, obj) if err != nil { + if directDeleteErr != nil { + return fmt.Errorf("failed to permanently delete recycle-bin candidate: %w; fallback lookup failed: %v", directDeleteErr, err) + } return err } + if err := d.deleteRecycleBinEntry(ctx, recycleEntry.ID); err != nil { + if directDeleteErr != nil { + return fmt.Errorf("failed to permanently delete recycle-bin entry %s after candidate delete error %v: %w", recycleEntry.ID, directDeleteErr, err) + } + return err + } + return nil +} + +func (d *Open115) deleteRecycleBinEntry(ctx context.Context, tid string) error { + if err := d.WaitLimit(ctx); err != nil { + return err + } + _, err := d.client.RbDelete(ctx, tid) + return err +} + +func (d *Open115) findRecycleBinEntry(ctx context.Context, obj *Obj) (*sdk.RbListResp_FileInfo, error) { + pageSize := d.PageSize + if pageSize <= 0 { + pageSize = 200 + } else if pageSize > 1150 { + pageSize = 1150 + } + + offset := int64(0) + for { + if err := d.WaitLimit(ctx); err != nil { + return nil, err + } + resp, err := d.client.RbList(ctx, pageSize, offset) + if err != nil { + return nil, err + } + if entry := matchRecycleBinEntry(obj, resp.Files); entry != nil { + return entry, nil + } + + count, err := strconv.ParseInt(resp.Count, 10, 64) + if err != nil { + return nil, fmt.Errorf("parse recycle bin count %q: %w", resp.Count, err) + } + offset += pageSize + if offset >= count || len(resp.Files) == 0 { + break + } + } + + return nil, fmt.Errorf("recycle bin entry not found for object id=%s name=%s parent=%s", obj.GetID(), obj.GetName(), obj.Pid) +} + +func isRecycleBinEntryNotFoundErr(err error) bool { + return err != nil && strings.Contains(err.Error(), "recycle bin entry not found") +} + +func (d *Open115) findRecycleBinEntryWithRetry(ctx context.Context, obj *Obj) (*sdk.RbListResp_FileInfo, error) { + attempts := recycleBinLookupMaxAttempts + if attempts < 1 { + attempts = 1 + } + + var lastErr error + for i := 0; i < attempts; i++ { + entry, err := d.findRecycleBinEntry(ctx, obj) + if err == nil { + return entry, nil + } + + lastErr = err + if !isRecycleBinEntryNotFoundErr(err) || i == attempts-1 { + break + } + + wait := recycleBinLookupRetryDelay * time.Duration(i+1) + if wait <= 0 { + continue + } + + timer := time.NewTimer(wait) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + } + + return nil, lastErr +} + +func matchRecycleBinEntry(obj *Obj, files map[string]sdk.RbListResp_FileInfo) *sdk.RbListResp_FileInfo { + if len(files) == 0 { + return nil + } + if entry, ok := files[obj.GetID()]; ok { + matched := entry + return &matched + } + + size := strconv.FormatInt(obj.GetSize(), 10) + for _, entry := range files { + if entry.ID == obj.GetID() { + matched := entry + return &matched + } + cid := string(entry.CID) + if obj.IsDir() { + if entry.FileName == obj.GetName() && cid == obj.Pid { + matched := entry + return &matched + } + continue + } + if obj.Sha1 != "" && entry.SHA1 != "" && strings.EqualFold(entry.SHA1, obj.Sha1) { + if entry.FileName == obj.GetName() || cid == obj.Pid { + matched := entry + return &matched + } + } + if entry.FileName == obj.GetName() && cid == obj.Pid && entry.FileSize == size { + matched := entry + return &matched + } + } + return nil +} + +func (d *Open115) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { + var err error sha1 := file.GetHash().GetHash(utils.SHA1) - if len(sha1) != utils.SHA1.Width { - _, sha1, err = stream.CacheFullAndHash(file, &up, utils.SHA1) + sha1128k := file.GetHash().GetHash(utils.SHA1_128K) + + // 检查是否是可重复读取的流 + _, isSeekable := file.(*stream.SeekableStream) + + // 如果有预计算的 hash,先尝试秒传 + if len(sha1) == utils.SHA1.Width && len(sha1128k) == utils.SHA1_128K.Width { + if err := d.WaitLimit(ctx); err != nil { + return err + } + resp, err := d.client.UploadInit(ctx, &sdk.UploadInitReq{ + FileName: file.GetName(), + FileSize: file.GetSize(), + Target: dstDir.GetID(), + FileID: strings.ToUpper(sha1), + PreID: strings.ToUpper(sha1128k), + }) if err != nil { return err } + if resp.Status == 2 { + up(100) + return nil + } + // 秒传失败,继续后续流程 } - const PreHashSize int64 = 128 * utils.KB - hashSize := PreHashSize - if file.GetSize() < PreHashSize { - hashSize = file.GetSize() - } - reader, err := file.RangeRead(http_range.Range{Start: 0, Length: hashSize}) - if err != nil { - return err + + if isSeekable { + // 可重复读取的流,使用 RangeRead 计算 hash,不缓存 + if len(sha1) != utils.SHA1.Width { + sha1, err = stream.StreamHashFile(file, utils.SHA1, 100, &up) + if err != nil { + return err + } + } + // 计算 sha1_128k(如果没有预计算) + if len(sha1128k) != utils.SHA1_128K.Width { + const PreHashSize int64 = 128 * utils.KB + hashSize := PreHashSize + if file.GetSize() < PreHashSize { + hashSize = file.GetSize() + } + reader, err := file.RangeRead(http_range.Range{Start: 0, Length: hashSize}) + if err != nil { + return err + } + sha1128k, err = utils.HashReader(utils.SHA1, reader) + if err != nil { + return err + } + } + } else { + // 不可重复读取的流(如 HTTP body) + // 如果有预计算的 hash,上面已经尝试过秒传了 + if len(sha1) == utils.SHA1.Width && len(sha1128k) == utils.SHA1_128K.Width { + // 秒传失败,需要缓存文件进行实际上传 + _, err = file.CacheFullAndWriter(&up, nil) + if err != nil { + return err + } + } else { + // 没有预计算的 hash,缓存整个文件并计算 + if len(sha1) != utils.SHA1.Width { + _, sha1, err = stream.CacheFullAndHash(file, &up, utils.SHA1) + if err != nil { + return err + } + } else if file.GetFile() == nil { + // 有 SHA1 但没有缓存,需要缓存以支持后续 RangeRead + _, err = file.CacheFullAndWriter(&up, nil) + if err != nil { + return err + } + } + // 计算 sha1_128k + const PreHashSize int64 = 128 * utils.KB + hashSize := PreHashSize + if file.GetSize() < PreHashSize { + hashSize = file.GetSize() + } + reader, err := file.RangeRead(http_range.Range{Start: 0, Length: hashSize}) + if err != nil { + return err + } + sha1128k, err = utils.HashReader(utils.SHA1, reader) + if err != nil { + return err + } + } } - sha1128k, err := utils.HashReader(utils.SHA1, reader) - if err != nil { + + // 1. Init(SeekableStream 或已缓存的 FileStream) + if err := d.WaitLimit(ctx); err != nil { return err } - // 1. Init resp, err := d.client.UploadInit(ctx, &sdk.UploadInitReq{ FileName: file.GetName(), FileSize: file.GetSize(), @@ -315,14 +585,17 @@ func (d *Open115) Put(ctx context.Context, dstDir model.Obj, file model.FileStre if err != nil { return err } - reader, err = file.RangeRead(http_range.Range{Start: start, Length: end - start + 1}) + signReader, err := file.RangeRead(http_range.Range{Start: start, Length: end - start + 1}) if err != nil { return err } - signVal, err := utils.HashReader(utils.SHA1, reader) + signVal, err := utils.HashReader(utils.SHA1, signReader) if err != nil { return err } + if err := d.WaitLimit(ctx); err != nil { + return err + } resp, err = d.client.UploadInit(ctx, &sdk.UploadInitReq{ FileName: file.GetName(), FileSize: file.GetSize(), @@ -341,6 +614,9 @@ func (d *Open115) Put(ctx context.Context, dstDir model.Obj, file model.FileStre } } // 3. get upload token + if err := d.WaitLimit(ctx); err != nil { + return err + } tokenResp, err := d.client.UploadGetToken(ctx) if err != nil { return err @@ -357,15 +633,48 @@ func (d *Open115) OfflineDownload(ctx context.Context, uris []string, dstDir mod return d.client.AddOfflineTaskURIs(ctx, uris, dstDir.GetID()) } +func (d *Open115) OfflineDownloadWithDetails(ctx context.Context, uris []string, dstDir model.Obj) ([]string, []sdk.AddOfflineTaskURIsResp, string, error) { + var envelope sdk.Resp[[]sdk.AddOfflineTaskURIsResp] + response, err := d.client.AuthRequestRaw(ctx, sdk.ApiAddOffline, http.MethodPost, nil, sdk.ReqWithForm(sdk.Form{ + "urls": strings.Join(uris, "\n"), + "wp_path_id": dstDir.GetID(), + })) + if response != nil { + _ = json.Unmarshal(response.Bytes(), &envelope) + } + hashes := make([]string, 0, len(envelope.Data)) + for _, item := range envelope.Data { + if item.State && item.InfoHash != "" { + hashes = append(hashes, item.InfoHash) + } + } + rawResponse := "" + if response != nil { + rawResponse = response.String() + } + return hashes, envelope.Data, rawResponse, err +} + func (d *Open115) DeleteOfflineTask(ctx context.Context, infoHash string, deleteFiles bool) error { return d.client.DeleteOfflineTask(ctx, infoHash, deleteFiles) } func (d *Open115) OfflineList(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + // 获取第一页 resp, err := d.client.OfflineTaskList(ctx, 1) if err != nil { return nil, err } + // 如果有多页,获取所有页面的任务 + if resp.PageCount > 1 { + for page := 2; page <= resp.PageCount; page++ { + pageResp, err := d.client.OfflineTaskList(ctx, int64(page)) + if err != nil { + return nil, err + } + resp.Tasks = append(resp.Tasks, pageResp.Tasks...) + } + } return resp, nil } diff --git a/drivers/115_open/driver_test.go b/drivers/115_open/driver_test.go new file mode 100644 index 000000000..489810327 --- /dev/null +++ b/drivers/115_open/driver_test.go @@ -0,0 +1,867 @@ +package _115_open + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" + "slices" + "strings" + "sync" + "testing" + "time" + + sdk "github.com/OpenListTeam/115-sdk-go" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "golang.org/x/time/rate" +) + +type recordedRequest struct { + Path string + Form url.Values + Time time.Time +} + +type rewriteTransport struct { + target *url.URL + base http.RoundTripper +} + +func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + cloned := req.Clone(req.Context()) + cloned.URL.Scheme = t.target.Scheme + cloned.URL.Host = t.target.Host + return t.base.RoundTrip(cloned) +} + +func TestOpen115RemoveTrashUsesDelFileOnly(t *testing.T) { + driver, requests := newTestOpen115(t, "trash", func(w http.ResponseWriter, r *http.Request) { + writeSDKSuccess(t, w, []string{"rb-123"}) + }) + + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", FS: 123, Sha1: "sha-demo"} + if err := driver.Remove(context.Background(), obj); err != nil { + t.Fatalf("Remove returned error: %v", err) + } + + assertRequestPaths(t, requests(), "/open/ufile/delete") + assertFormValue(t, requests()[0].Form, "file_ids", "file-1") + assertFormValue(t, requests()[0].Form, "parent_id", "dir-1") +} + +func TestOpen115RemoveDeleteUsesDelFileResponseIDWhenAvailable(t *testing.T) { + driver, requests := newTestOpen115(t, "delete", func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/ufile/delete": + writeSDKSuccess(t, w, []string{"rb-123"}) + case "/open/rb/del": + writeSDKSuccess(t, w, []string{"rb-123"}) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", FS: 123, Sha1: "sha-demo"} + if err := driver.Remove(context.Background(), obj); err != nil { + t.Fatalf("Remove returned error: %v", err) + } + + assertRequestPaths(t, requests(), "/open/ufile/delete", "/open/rb/del") + assertFormValue(t, requests()[1].Form, "tid", "rb-123") +} + +func TestOpen115RemoveDeleteFallsBackToRecycleBinLookup(t *testing.T) { + driver, requests := newTestOpen115(t, "delete", func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/ufile/delete": + writeSDKSuccess(t, w, []string{"file-1"}) + case "/open/rb/del": + if r.FormValue("tid") == "file-1" { + writeSDKError(t, w, 404, "not found") + return + } + writeSDKSuccess(t, w, []string{"rb-123"}) + case "/open/rb/list": + writeSDKSuccess(t, w, map[string]any{ + "offset": 0, + "limit": 1, + "count": "1", + "rb_pass": 0, + "rb-123": map[string]any{ + "id": "rb-123", + "file_name": "demo.txt", + "file_size": "123", + "cid": "dir-1", + "sha1": "sha-demo", + }, + }) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", FS: 123, Sha1: "sha-demo"} + if err := driver.Remove(context.Background(), obj); err != nil { + t.Fatalf("Remove returned error: %v", err) + } + + assertRequestPaths(t, requests(), "/open/ufile/delete", "/open/rb/del", "/open/rb/list", "/open/rb/del") + assertFormValue(t, requests()[3].Form, "tid", "rb-123") +} + +func TestOpen115RemoveDeleteReturnsErrorWhenRecycleEntryMissing(t *testing.T) { + driver, _ := newTestOpen115(t, "delete", func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/ufile/delete": + writeSDKSuccess(t, w, []string{}) + case "/open/rb/list": + writeSDKSuccess(t, w, map[string]any{ + "offset": 0, + "limit": 1, + "count": "0", + "rb_pass": 0, + }) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", FS: 123, Sha1: "sha-demo"} + err := driver.Remove(context.Background(), obj) + if err == nil { + t.Fatalf("expected Remove to fail when recycle-bin entry is missing") + } + if !strings.Contains(err.Error(), "recycle bin entry not found") { + t.Fatalf("expected recycle-bin lookup error, got: %v", err) + } +} + +func TestOpen115DriverInfoIncludesRemoveWay(t *testing.T) { + info, ok := op.GetDriverInfoMap()["115 Open"] + if !ok { + t.Fatalf("115 Open driver info was not registered") + } + + for _, item := range info.Additional { + if item.Name != "remove_way" { + continue + } + if item.Type != "select" { + t.Fatalf("unexpected remove_way type: %q", item.Type) + } + if item.Options != "trash,delete" { + t.Fatalf("unexpected remove_way options: %q", item.Options) + } + if item.Default != "trash" { + t.Fatalf("unexpected remove_way default: %q", item.Default) + } + if !item.Required { + t.Fatalf("expected remove_way to be required") + } + return + } + + t.Fatalf("remove_way item not found in 115 Open driver info") +} + +func newTestOpen115(t *testing.T, removeWay string, responder http.HandlerFunc) (*Open115, func() []recordedRequest) { + t.Helper() + + var ( + mu sync.Mutex + requests []recordedRequest + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm failed: %v", err) + } + mu.Lock() + requests = append(requests, recordedRequest{ + Path: r.URL.Path, + Form: cloneValues(r.Form), + Time: time.Now(), + }) + mu.Unlock() + responder(w, r) + })) + t.Cleanup(server.Close) + + target, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Parse server URL failed: %v", err) + } + + client := sdk.New(sdk.WithAccessToken("test-token")) + client.SetHttpClient(&http.Client{ + Transport: &rewriteTransport{ + target: target, + base: http.DefaultTransport, + }, + }) + + return &Open115{ + Addition: Addition{ + RemoveWay: removeWay, + PageSize: 1, + }, + client: client, + }, func() []recordedRequest { + mu.Lock() + defer mu.Unlock() + return append([]recordedRequest(nil), requests...) + } +} + +func writeSDKSuccess(t *testing.T, w http.ResponseWriter, data any) { + t.Helper() + writeSDKResponse(t, w, map[string]any{ + "state": true, + "data": data, + }) +} + +func writeSDKError(t *testing.T, w http.ResponseWriter, code int64, message string) { + t.Helper() + writeSDKResponse(t, w, map[string]any{ + "state": false, + "code": code, + "message": message, + }) +} + +func writeSDKResponse(t *testing.T, w http.ResponseWriter, payload map[string]any) { + t.Helper() + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(payload); err != nil { + t.Fatalf("Encode response failed: %v", err) + } +} + +func assertRequestPaths(t *testing.T, requests []recordedRequest, want ...string) { + t.Helper() + got := make([]string, 0, len(requests)) + for _, req := range requests { + got = append(got, req.Path) + } + if !slices.Equal(got, want) { + t.Fatalf("unexpected request paths: got %v want %v", got, want) + } +} + +func assertFormValue(t *testing.T, form url.Values, key, want string) { + t.Helper() + if got := form.Get(key); got != want { + t.Fatalf("unexpected form value for %s: got %q want %q", key, got, want) + } +} + +func cloneValues(src url.Values) url.Values { + dst := make(url.Values, len(src)) + for key, values := range src { + dst[key] = append([]string(nil), values...) + } + return dst +} + +// --- FlexString / numeric CID tests --- + +func TestFlexStringUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"string value", `{"cid":"dir-1"}`, "dir-1"}, + {"integer value", `{"cid":3383942108160578280}`, "3383942108160578280"}, + {"large integer", `{"cid":9999999999999999999}`, "9999999999999999999"}, + {"zero", `{"cid":0}`, "0"}, + {"negative", `{"cid":-123}`, "-123"}, + {"float", `{"cid":1.5}`, "1.5"}, + {"empty string", `{"cid":""}`, ""}, + {"null", `{"cid":null}`, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var v struct { + CID sdk.FlexString `json:"cid"` + } + if err := json.Unmarshal([]byte(tt.input), &v); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if got := string(v.CID); got != tt.want { + t.Fatalf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestFlexStringUnmarshalInvalid(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"boolean", `{"cid":true}`}, + {"array", `{"cid":[1]}`}, + {"object", `{"cid":{}}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var v struct { + CID sdk.FlexString `json:"cid"` + } + if err := json.Unmarshal([]byte(tt.input), &v); err == nil { + t.Fatalf("expected error for input %s", tt.input) + } + }) + } +} + +func TestRbListRespUnmarshalCIDAsString(t *testing.T) { + raw := `{"id":"rb-1","file_name":"demo.txt","cid":"dir-1","file_size":"123"}` + var info sdk.RbListResp_FileInfo + if err := json.Unmarshal([]byte(raw), &info); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if string(info.CID) != "dir-1" { + t.Fatalf("got CID %q, want %q", string(info.CID), "dir-1") + } +} + +func TestRbListRespUnmarshalCIDAsNumber(t *testing.T) { + raw := `{"id":"rb-1","file_name":"MyFolder","cid":3383942108160578280,"file_size":"0"}` + var info sdk.RbListResp_FileInfo + if err := json.Unmarshal([]byte(raw), &info); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if string(info.CID) != "3383942108160578280" { + t.Fatalf("got CID %q, want %q", string(info.CID), "3383942108160578280") + } +} + +// --- matchRecycleBinEntry with numeric CID --- + +func TestMatchRecycleBinEntryDirMatchWithNumericCID(t *testing.T) { + // obj represents a directory with Pid (parent ID) as a large number + obj := &Obj{Fid: "folder-1", Pid: "3383942108160578280", Fn: "MyFolder", Fc: "0", FS: 0} + files := map[string]sdk.RbListResp_FileInfo{ + "rb-1": { + ID: "rb-folder-1", + FileName: "MyFolder", + CID: sdk.FlexString("3383942108160578280"), + }, + } + result := matchRecycleBinEntry(obj, files) + if result == nil { + t.Fatal("expected match for directory with numeric CID, got nil") + } + if result.ID != "rb-folder-1" { + t.Fatalf("got ID %q, want %q", result.ID, "rb-folder-1") + } +} + +func TestMatchRecycleBinEntryDirNoMatchWhenCIDWrong(t *testing.T) { + obj := &Obj{Fid: "folder-1", Pid: "3383942108160578280", Fn: "MyFolder", Fc: "0", FS: 0} + files := map[string]sdk.RbListResp_FileInfo{ + "rb-1": { + ID: "rb-folder-1", + FileName: "MyFolder", + CID: sdk.FlexString("9999999999"), + }, + } + result := matchRecycleBinEntry(obj, files) + if result != nil { + t.Fatalf("expected no match, got %+v", result) + } +} + +func TestMatchRecycleBinEntryFileSHA1MatchWithNumericCID(t *testing.T) { + obj := &Obj{Fid: "file-1", Pid: "3383942108160578280", Fn: "video.mp4", Fc: "1", FS: 1024, Sha1: "abc123"} + files := map[string]sdk.RbListResp_FileInfo{ + "rb-1": { + ID: "rb-file-1", + FileName: "video.mp4", + CID: sdk.FlexString("3383942108160578280"), + SHA1: "ABC123", + FileSize: "1024", + }, + } + result := matchRecycleBinEntry(obj, files) + if result == nil { + t.Fatal("expected match via SHA1+CID, got nil") + } + if result.ID != "rb-file-1" { + t.Fatalf("got ID %q, want %q", result.ID, "rb-file-1") + } +} + +func TestMatchRecycleBinEntryFileSHA1MatchByNameOnly(t *testing.T) { + obj := &Obj{Fid: "file-1", Pid: "wrong-pid", Fn: "video.mp4", Fc: "1", FS: 1024, Sha1: "abc123"} + files := map[string]sdk.RbListResp_FileInfo{ + "rb-1": { + ID: "rb-file-1", + FileName: "video.mp4", + CID: sdk.FlexString("3383942108160578280"), + SHA1: "ABC123", + FileSize: "1024", + }, + } + result := matchRecycleBinEntry(obj, files) + if result == nil { + t.Fatal("expected match via SHA1+name, got nil") + } +} + +func TestMatchRecycleBinEntryFileNameSizeCIDMatch(t *testing.T) { + obj := &Obj{Fid: "file-1", Pid: "3383942108160578280", Fn: "doc.pdf", Fc: "1", FS: 500} + files := map[string]sdk.RbListResp_FileInfo{ + "rb-1": { + ID: "rb-file-1", + FileName: "doc.pdf", + CID: sdk.FlexString("3383942108160578280"), + FileSize: "500", + }, + } + result := matchRecycleBinEntry(obj, files) + if result == nil { + t.Fatal("expected match via name+size+CID, got nil") + } +} + +func TestMatchRecycleBinEntryDirectIDMatch(t *testing.T) { + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", Fc: "1", FS: 123} + files := map[string]sdk.RbListResp_FileInfo{ + "file-1": { + ID: "rb-123", + FileName: "demo.txt", + CID: sdk.FlexString("dir-1"), + }, + } + result := matchRecycleBinEntry(obj, files) + if result == nil { + t.Fatal("expected direct ID match, got nil") + } + if result.ID != "rb-123" { + t.Fatalf("got ID %q, want %q", result.ID, "rb-123") + } +} + +func TestMatchRecycleBinEntryEmptyFiles(t *testing.T) { + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", Fc: "1", FS: 123} + result := matchRecycleBinEntry(obj, nil) + if result != nil { + t.Fatalf("expected nil for nil files, got %+v", result) + } + result = matchRecycleBinEntry(obj, map[string]sdk.RbListResp_FileInfo{}) + if result != nil { + t.Fatalf("expected nil for empty files, got %+v", result) + } +} + +// --- Full Remove flow with numeric CID in recycle bin --- + +func TestOpen115RemoveDeleteWithNumericCID(t *testing.T) { + driver, requests := newTestOpen115(t, "delete", func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/ufile/delete": + writeSDKSuccess(t, w, []string{"folder-1"}) + case "/open/rb/del": + if r.FormValue("tid") == "folder-1" { + writeSDKError(t, w, 404, "not found") + return + } + writeSDKSuccess(t, w, []string{"rb-folder-1"}) + case "/open/rb/list": + // CID returned as number (the real bug scenario) + writeSDKSuccess(t, w, map[string]any{ + "offset": 0, + "limit": 1, + "count": "1", + "rb_pass": 0, + "rb-folder-1": map[string]any{ + "id": "rb-folder-1", + "file_name": "MyFolder", + "cid": 3383942108160578280, + "file_size": "0", + }, + }) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + obj := &Obj{Fid: "folder-1", Pid: "3383942108160578280", Fn: "MyFolder", Fc: "0", FS: 0} + if err := driver.Remove(context.Background(), obj); err != nil { + t.Fatalf("Remove returned error: %v", err) + } + + assertRequestPaths(t, requests(), "/open/ufile/delete", "/open/rb/del", "/open/rb/list", "/open/rb/del") + assertFormValue(t, requests()[3].Form, "tid", "rb-folder-1") +} + +func TestOpen115RemoveDeleteWithStringCIDStillWorks(t *testing.T) { + driver, requests := newTestOpen115(t, "delete", func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/ufile/delete": + writeSDKSuccess(t, w, []string{"file-1"}) + case "/open/rb/del": + if r.FormValue("tid") == "file-1" { + writeSDKError(t, w, 404, "not found") + return + } + writeSDKSuccess(t, w, []string{"rb-123"}) + case "/open/rb/list": + // CID returned as string (normal case) + writeSDKSuccess(t, w, map[string]any{ + "offset": 0, + "limit": 1, + "count": "1", + "rb_pass": 0, + "rb-123": map[string]any{ + "id": "rb-123", + "file_name": "demo.txt", + "cid": "dir-1", + "sha1": "sha-demo", + "file_size": "123", + }, + }) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", Fc: "1", FS: 123, Sha1: "sha-demo"} + if err := driver.Remove(context.Background(), obj); err != nil { + t.Fatalf("Remove returned error: %v", err) + } + + assertRequestPaths(t, requests(), "/open/ufile/delete", "/open/rb/del", "/open/rb/list", "/open/rb/del") + assertFormValue(t, requests()[3].Form, "tid", "rb-123") +} + +func TestOpen115RemoveDeleteRetriesRecycleBinLookupUntilVisible(t *testing.T) { + oldAttempts, oldDelay := recycleBinLookupMaxAttempts, recycleBinLookupRetryDelay + recycleBinLookupMaxAttempts = 3 + recycleBinLookupRetryDelay = time.Millisecond + t.Cleanup(func() { + recycleBinLookupMaxAttempts = oldAttempts + recycleBinLookupRetryDelay = oldDelay + }) + + rbListCalls := 0 + driver, requests := newTestOpen115(t, "delete", func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/ufile/delete": + writeSDKSuccess(t, w, []string{"file-1"}) + case "/open/rb/del": + if r.FormValue("tid") == "file-1" { + writeSDKError(t, w, 404, "not found") + return + } + writeSDKSuccess(t, w, []string{"rb-123"}) + case "/open/rb/list": + rbListCalls++ + if rbListCalls < 3 { + writeSDKSuccess(t, w, map[string]any{ + "offset": 0, + "limit": 1, + "count": "0", + "rb_pass": 0, + }) + return + } + writeSDKSuccess(t, w, map[string]any{ + "offset": 0, + "limit": 1, + "count": "1", + "rb_pass": 0, + "rb-123": map[string]any{ + "id": "rb-123", + "file_name": "demo.txt", + "cid": "dir-1", + "sha1": "sha-demo", + "file_size": "123", + }, + }) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", Fc: "1", FS: 123, Sha1: "sha-demo"} + if err := driver.Remove(context.Background(), obj); err != nil { + t.Fatalf("Remove returned error: %v", err) + } + + if rbListCalls != 3 { + t.Fatalf("rbListCalls = %d, want 3", rbListCalls) + } + assertRequestPaths(t, requests(), "/open/ufile/delete", "/open/rb/del", "/open/rb/list", "/open/rb/list", "/open/rb/list", "/open/rb/del") + assertFormValue(t, requests()[5].Form, "tid", "rb-123") +} + +func TestOpen115RemoveDeleteStopsRetryWhenContextCancelled(t *testing.T) { + oldAttempts, oldDelay := recycleBinLookupMaxAttempts, recycleBinLookupRetryDelay + recycleBinLookupMaxAttempts = 5 + recycleBinLookupRetryDelay = 50 * time.Millisecond + t.Cleanup(func() { + recycleBinLookupMaxAttempts = oldAttempts + recycleBinLookupRetryDelay = oldDelay + }) + + driver, _ := newTestOpen115(t, "delete", func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/ufile/delete": + writeSDKSuccess(t, w, []string{"file-1"}) + case "/open/rb/del": + writeSDKError(t, w, 404, "not found") + case "/open/rb/list": + writeSDKSuccess(t, w, map[string]any{ + "offset": 0, + "limit": 1, + "count": "0", + "rb_pass": 0, + }) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + obj := &Obj{Fid: "file-1", Pid: "dir-1", Fn: "demo.txt", Fc: "1", FS: 123, Sha1: "sha-demo"} + err := driver.Remove(ctx, obj) + if err == nil { + t.Fatalf("expected Remove to fail due to context cancellation") + } + if !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) { + t.Fatalf("expected context deadline exceeded, got: %v", err) + } +} + +// --- Put rate-limiting tests --- + +// mockFileStreamer satisfies model.FileStreamer with pre-computed hashes for testing. +type mockFileStreamer struct { + name string + size int64 + hashInfo utils.HashInfo + data []byte +} + +func (m *mockFileStreamer) Read(p []byte) (int, error) { return 0, io.EOF } +func (m *mockFileStreamer) Close() error { return nil } +func (m *mockFileStreamer) Add(_ io.Closer) {} +func (m *mockFileStreamer) AddIfCloser(_ any) {} +func (m *mockFileStreamer) GetSize() int64 { return m.size } +func (m *mockFileStreamer) GetName() string { return m.name } +func (m *mockFileStreamer) ModTime() time.Time { return time.Time{} } +func (m *mockFileStreamer) CreateTime() time.Time { return time.Time{} } +func (m *mockFileStreamer) IsDir() bool { return false } +func (m *mockFileStreamer) GetHash() utils.HashInfo { return m.hashInfo } +func (m *mockFileStreamer) GetID() string { return "" } +func (m *mockFileStreamer) GetPath() string { return "" } +func (m *mockFileStreamer) GetMimetype() string { return "application/octet-stream" } +func (m *mockFileStreamer) NeedStore() bool { return false } +func (m *mockFileStreamer) IsForceStreamUpload() bool { return false } +func (m *mockFileStreamer) GetExist() model.Obj { return nil } +func (m *mockFileStreamer) SetExist(_ model.Obj) {} +func (m *mockFileStreamer) GetFile() model.File { return nil } +func (m *mockFileStreamer) RangeRead(_ http_range.Range) (io.Reader, error) { + return strings.NewReader(string(m.data)), nil +} +func (m *mockFileStreamer) CacheFullAndWriter(_ *model.UpdateProgress, _ io.Writer) (model.File, error) { + return nil, nil +} + +func newTestOpen115WithRateLimit(t *testing.T, limitRate float64, responder http.HandlerFunc) (*Open115, func() []recordedRequest) { + t.Helper() + driver, requests := newTestOpen115(t, "trash", responder) + driver.limiter = rate.NewLimiter(rate.Limit(limitRate), 1) + return driver, requests +} + +func TestPutRateLimitsEverySDKCall(t *testing.T) { + // Rate limit: 10 req/s → each WaitLimit blocks ~100ms + const limitRate = 10.0 + + driver, requests := newTestOpen115WithRateLimit(t, limitRate, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/upload/init": + // First call: status=1 (not rapid), second call: status=2 (rapid success) + if r.FormValue("sign_key") != "" { + writeSDKSuccess(t, w, map[string]any{"status": 2}) + } else { + writeSDKSuccess(t, w, map[string]any{ + "status": 7, + "sign_key": "test-key", + "sign_check": "0-10", + }) + } + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + stream := &mockFileStreamer{ + name: "test.txt", + size: 100, + hashInfo: utils.NewHashInfoByMap(map[*utils.HashType]string{ + utils.SHA1: "da39a3ee5e6b4b0d3255bfef95601890afd80709", + utils.SHA1_128K: "da39a3ee5e6b4b0d3255bfef95601890afd80709", + }), + data: make([]byte, 100), + } + dstDir := &model.Object{ID: "0", Name: "root", IsFolder: true} + up := func(float64) {} + + start := time.Now() + err := driver.Put(context.Background(), dstDir, stream, up) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("Put returned error: %v", err) + } + + reqs := requests() + // Expect: pre-hash UploadInit + main UploadInit + sign-check UploadInit = 3 calls + if len(reqs) < 3 { + t.Fatalf("expected at least 3 requests, got %d", len(reqs)) + } + assertRequestPaths(t, reqs, "/open/upload/init", "/open/upload/init", "/open/upload/init") + + // With 2 SDK calls each preceded by WaitLimit(10/s), minimum elapsed is ~100ms. + // Without WaitLimit, both calls fire instantly (<10ms). + minExpected := time.Duration(float64(time.Second) / limitRate * float64(len(reqs)-1)) + tolerance := minExpected * 7 / 10 // 70% to account for timing jitter + if elapsed < tolerance { + t.Fatalf("Put completed too fast (%v), expected at least %v — WaitLimit likely missing before some SDK calls", elapsed, tolerance) + } + + // Also verify individual request spacing + for i := 1; i < len(reqs); i++ { + gap := reqs[i].Time.Sub(reqs[i-1].Time) + gapMin := time.Duration(float64(time.Second) / limitRate * 0.7) + if gap < gapMin { + t.Fatalf("gap between request %d and %d is %v, expected at least %v — WaitLimit missing", i-1, i, gap, gapMin) + } + } +} + +func TestPutRateLimitsPreHashPath(t *testing.T) { + const limitRate = 10.0 + + driver, requests := newTestOpen115WithRateLimit(t, limitRate, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/open/upload/init": + // Rapid upload success on first try + writeSDKSuccess(t, w, map[string]any{"status": 2}) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + }) + + stream := &mockFileStreamer{ + name: "test.txt", + size: 100, + hashInfo: utils.NewHashInfoByMap(map[*utils.HashType]string{ + utils.SHA1: "da39a3ee5e6b4b0d3255bfef95601890afd80709", + utils.SHA1_128K: "da39a3ee5e6b4b0d3255bfef95601890afd80709", + }), + data: make([]byte, 100), + } + dstDir := &model.Object{ID: "0", Name: "root", IsFolder: true} + + err := driver.Put(context.Background(), dstDir, stream, func(float64) {}) + if err != nil { + t.Fatalf("Put returned error: %v", err) + } + + reqs := requests() + if len(reqs) != 1 { + t.Fatalf("expected 1 request, got %d: %v", len(reqs), reqs) + } + assertRequestPaths(t, reqs, "/open/upload/init") +} + +func TestGetReturnsObjectNotFoundForEmptyData(t *testing.T) { + driver, _ := newTestOpen115(t, "trash", func(w http.ResponseWriter, r *http.Request) { + // 115 returns empty array when path doesn't exist + writeSDKSuccess(t, w, []any{}) + }) + driver.parentPath = "" + + _, err := driver.Get(context.Background(), "/nonexistent/path") + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, errs.ObjectNotFound) { + t.Fatalf("expected ObjectNotFound, got: %v", err) + } +} + +func TestCheckUploadCallbackSuccess(t *testing.T) { + body := []byte(`{"state":true,"code":0,"message":"success","data":{"pick_code":"abc","file_name":"test.txt","file_size":100,"file_id":"123","sha1":"da39a3ee","cid":"456"}}`) + if err := checkUploadCallback(body); err != nil { + t.Fatalf("expected nil error, got: %v", err) + } +} + +func TestCheckUploadCallbackStateFalse(t *testing.T) { + body := []byte(`{"state":false,"code":990009,"message":"upload failed"}`) + err := checkUploadCallback(body) + if err == nil { + t.Fatal("expected error for state=false, got nil") + } + if !strings.Contains(err.Error(), "990009") || !strings.Contains(err.Error(), "upload failed") { + t.Fatalf("error should contain code and message, got: %v", err) + } +} + +func TestCheckUploadCallbackEmptyBody(t *testing.T) { + err := checkUploadCallback([]byte{}) + if err == nil { + t.Fatal("expected error for empty body, got nil") + } + if !strings.Contains(err.Error(), "empty") { + t.Fatalf("error should mention empty, got: %v", err) + } +} + +func TestCheckUploadCallbackInvalidJSON(t *testing.T) { + err := checkUploadCallback([]byte(`not json`)) + if err == nil { + t.Fatal("expected error for invalid JSON, got nil") + } + if !strings.Contains(err.Error(), "parse error") { + t.Fatalf("error should mention parse error, got: %v", err) + } +} + +func TestGetReturnsObjForExistingFolder(t *testing.T) { + driver, _ := newTestOpen115(t, "trash", func(w http.ResponseWriter, r *http.Request) { + writeSDKSuccess(t, w, map[string]any{ + "file_id": "99999", + "file_name": "my_folder", + "pick_code": "pc-123", + "file_category": "0", // folder; Fix 4 rejects file responses with NotImplement + }) + }) + driver.parentPath = "" + + obj, err := driver.Get(context.Background(), "/my_folder") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if obj.GetID() != "99999" || obj.GetName() != "my_folder" { + t.Fatalf("unexpected obj: id=%s name=%s", obj.GetID(), obj.GetName()) + } +} diff --git a/drivers/115_open/get.go b/drivers/115_open/get.go new file mode 100644 index 000000000..d36f0c97f --- /dev/null +++ b/drivers/115_open/get.go @@ -0,0 +1,46 @@ +package _115_open + +import ( + sdk "github.com/OpenListTeam/115-sdk-go" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/model" +) + +// folderInfoToObj extracts the fields driver.Get needs from a +// GetFolderInfoByPath response. We only look at FileID + FileName + +// FileCategory — Sha1 / PickCode come along because the struct already +// carries them, but Size is intentionally dropped: the API is +// folder-oriented and resp.Size is empty/garbage for file paths. +// fromFolderInfo rejects files outright so the parsed value would never +// be read anyway. +func folderInfoToObj(resp *sdk.GetFolderInfoResp) *Obj { + return &Obj{ + Fid: resp.FileID, + Fn: resp.FileName, + Fc: resp.FileCategory, + Sha1: resp.Sha1, + Pc: resp.PickCode, + } +} + +// fromFolderInfo is the gateway used by driver.Get. Files are rejected +// with errs.NotImplement so op.Get falls through to its list-based +// path — that response carries GetFilesResp_File.FS (int64) and is +// always correct. Folders take the fast path: one +// GetFolderInfoByPath call is enough to build the Obj that +// op.list will pass to driver.List as the parent directory. +// +// Trade-off: cold-cache file access pays one wasted +// GetFolderInfoByPath (this call) + one folder GetFolderInfoByPath +// (parent) + one GetFiles (list parent) + the eventual DownURL = 4 +// WaitLimit-gated SDK calls. Default limit_rate is 1 req/s so this +// is ~3s of pure rate-limit wait on top of network. Steady-state +// access within dirCache TTL (5 min) collapses back to a single +// DownURL call. +func fromFolderInfo(resp *sdk.GetFolderInfoResp) (model.Obj, error) { + obj := folderInfoToObj(resp) + if !obj.IsDir() { + return nil, errs.NotImplement + } + return obj, nil +} diff --git a/drivers/115_open/get_test.go b/drivers/115_open/get_test.go new file mode 100644 index 000000000..638e65b69 --- /dev/null +++ b/drivers/115_open/get_test.go @@ -0,0 +1,59 @@ +package _115_open + +import ( + "errors" + "testing" + + sdk "github.com/OpenListTeam/115-sdk-go" + "github.com/OpenListTeam/OpenList/v4/internal/errs" +) + +// driver.Get's only job after Fix 4 is "is this a folder, and if so what's +// its FileID/name". Files are rejected with errs.NotImplement so op.Get +// falls back to the list-based path — that response carries +// GetFilesResp_File.FS as int64 and is always correct. The 115 +// "Get folder info" API (yuque rl8zrhe2nag21dfw) is folder-oriented and +// resp.Size is unreliable for file paths anyway. + +func TestFolderInfoToObj_Folder(t *testing.T) { + resp := &sdk.GetFolderInfoResp{ + FileID: "1234", FileName: "Movies", FileCategory: "0", + } + obj := folderInfoToObj(resp) + if !obj.IsDir() { + t.Fatalf("IsDir() = false, want true (FileCategory 0 = folder)") + } + if obj.GetID() != "1234" || obj.GetName() != "Movies" { + t.Fatalf("obj = %+v, want id=1234 name=Movies", obj) + } +} + +func TestFromFolderInfo_FileReturnsNotImplement(t *testing.T) { + // File path response: even with sha1/pickcode populated, fast-path + // must defer to list — Size is not byte-accurate for files. + resp := &sdk.GetFolderInfoResp{ + FileID: "3370024725891437210", FileName: "Django Unchained 2012.mkv", + FileCategory: "1", Sha1: "17C6301550DE8E22D477EB9BA3901A99B9961494", + PickCode: "cqghjg71avvncddhi", Size: "36767958354", + } + obj, err := fromFolderInfo(resp) + if obj != nil { + t.Fatalf("obj = %+v, want nil (files must defer to List)", obj) + } + if !errors.Is(err, errs.NotImplement) { + t.Fatalf("err = %v, want errs.NotImplement (so op.Get falls through to list path)", err) + } +} + +func TestFromFolderInfo_FolderReturnsObj(t *testing.T) { + resp := &sdk.GetFolderInfoResp{ + FileID: "1234", FileName: "Movies", FileCategory: "0", + } + obj, err := fromFolderInfo(resp) + if err != nil { + t.Fatalf("unexpected err = %v", err) + } + if obj == nil || !obj.IsDir() || obj.GetID() != "1234" { + t.Fatalf("obj = %+v, want non-nil folder with ID 1234", obj) + } +} diff --git a/drivers/115_open/link_expiry.go b/drivers/115_open/link_expiry.go new file mode 100644 index 000000000..f92f401c4 --- /dev/null +++ b/drivers/115_open/link_expiry.go @@ -0,0 +1,40 @@ +package _115_open + +import ( + "net/url" + "strconv" + "time" +) + +// 115 CDN download URLs carry `?t=` marking when the signed URL stops +// serving bytes. The real lifetime is often a few minutes, well below +// OpenList's default link-cache TTL — when the cache outlives the URL, OP +// hands out an expired CDN link and the response is "200 OK + empty body". +// +// parseCDNExpiry extracts that timestamp and turns it into a TTL suitable +// for model.Link.Expiration, with a safety margin so OP refreshes slightly +// before the CDN actually rejects. +const ( + cdnExpirySafetyMargin = 60 * time.Second + cdnExpiryMinimum = 1 * time.Second +) + +func parseCDNExpiry(rawURL string) (time.Duration, bool) { + parsed, err := url.Parse(rawURL) + if err != nil { + return 0, false + } + tStr := parsed.Query().Get("t") + if tStr == "" { + return 0, false + } + tUnix, err := strconv.ParseInt(tStr, 10, 64) + if err != nil { + return 0, false + } + remaining := time.Until(time.Unix(tUnix, 0)) - cdnExpirySafetyMargin + if remaining < cdnExpiryMinimum { + return cdnExpiryMinimum, true + } + return remaining, true +} diff --git a/drivers/115_open/link_expiry_test.go b/drivers/115_open/link_expiry_test.go new file mode 100644 index 000000000..4a19686ce --- /dev/null +++ b/drivers/115_open/link_expiry_test.go @@ -0,0 +1,150 @@ +package _115_open + +import ( + "fmt" + "testing" + "time" +) + +// 115 CDN download URLs carry a Unix timestamp in `t=...` that marks when +// the signed URL stops accepting requests. The actual lifetime is typically +// just a few minutes — much shorter than OpenList's default link-cache TTL. +// `parseCDNExpiry` reads `t=` and returns the time-to-live to plug into +// model.Link.Expiration so OP refreshes the link before the CDN does. + +func TestParseCDNExpiry_FutureTimestamp(t *testing.T) { + future := time.Now().Add(30 * time.Minute).Unix() + url := fmt.Sprintf("https://cdnfhnfdfs.115cdn.net/group518/foo.mkv?t=%d&u=123&s=52428800", future) + + d, ok := parseCDNExpiry(url) + if !ok { + t.Fatalf("ok = false, want true for URL with future t=") + } + // Expect 30min minus the safety margin. Allow ±5s slack for the + // time.Now() jitter between Unix() above and time.Until() inside. + want := 30*time.Minute - cdnExpirySafetyMargin + if d < want-5*time.Second || d > want+5*time.Second { + t.Fatalf("d = %v, want ~%v", d, want) + } +} + +func TestParseCDNExpiry_PastTimestamp_ClampedToMinimum(t *testing.T) { + past := time.Now().Add(-10 * time.Minute).Unix() + url := fmt.Sprintf("https://cdn.example.com/foo?t=%d", past) + + d, ok := parseCDNExpiry(url) + if !ok { + t.Fatalf("ok = false, want true even for expired t= (caller decides what to do)") + } + if d != cdnExpiryMinimum { + t.Fatalf("d = %v, want clamp to cdnExpiryMinimum=%v", d, cdnExpiryMinimum) + } +} + +func TestParseCDNExpiry_AboutToExpire_ClampedToMinimum(t *testing.T) { + // 30 seconds in the future, less than safety margin (60s) → would yield + // a negative duration. Should clamp. + soon := time.Now().Add(30 * time.Second).Unix() + url := fmt.Sprintf("https://cdn.example.com/foo?t=%d", soon) + + d, ok := parseCDNExpiry(url) + if !ok { + t.Fatalf("ok = false, want true") + } + if d != cdnExpiryMinimum { + t.Fatalf("d = %v, want clamp to cdnExpiryMinimum=%v", d, cdnExpiryMinimum) + } +} + +func TestParseCDNExpiry_MissingT(t *testing.T) { + d, ok := parseCDNExpiry("https://cdn.example.com/foo?u=123&s=52428800") + if ok { + t.Fatalf("ok = true, want false for URL without t=") + } + if d != 0 { + t.Fatalf("d = %v, want 0", d) + } +} + +func TestParseCDNExpiry_MalformedT(t *testing.T) { + d, ok := parseCDNExpiry("https://cdn.example.com/foo?t=notanumber&u=123") + if ok { + t.Fatalf("ok = true, want false for malformed t=") + } + if d != 0 { + t.Fatalf("d = %v, want 0", d) + } +} + +func TestParseCDNExpiry_EmptyT(t *testing.T) { + d, ok := parseCDNExpiry("https://cdn.example.com/foo?t=&u=123") + if ok { + t.Fatalf("ok = true, want false for empty t=") + } + if d != 0 { + t.Fatalf("d = %v, want 0", d) + } +} + +func TestParseCDNExpiry_NegativeT(t *testing.T) { + d, ok := parseCDNExpiry("https://cdn.example.com/foo?t=-1") + if ok { + // negative parses as Int64 but maps to a past time → ok=true with clamp + if d != cdnExpiryMinimum { + t.Fatalf("d = %v, want clamp to cdnExpiryMinimum=%v", d, cdnExpiryMinimum) + } + } else { + // alternative acceptable behavior: reject negative as malformed. + // Either policy is defensible; pin whichever the implementation chose. + if d != 0 { + t.Fatalf("d = %v, want 0", d) + } + } +} + +func TestParseCDNExpiry_InvalidURL(t *testing.T) { + // url.Parse is permissive — most "garbage in" still parses, but the + // query bag is empty so t= is missing. + d, ok := parseCDNExpiry("not_a_url_at_all") + if ok { + t.Fatalf("ok = true, want false for garbage URL with no query") + } + if d != 0 { + t.Fatalf("d = %v, want 0", d) + } +} + +func TestParseCDNExpiry_RealWorld115URL(t *testing.T) { + // Anchored on the actual URL the user pasted, but with t= rewritten to a + // known future point so the test isn't time-bombed. + future := time.Now().Add(2 * time.Hour).Unix() + url := fmt.Sprintf( + "https://cdnfhnfdfs.115cdn.net/group518/M00/5B/9F/tzyQp1JGFCUAAAAIj4qFUklVFs09176478/Django%%20Unchained%%202012.mkv?t=%d&u=103088508&s=52428800&d=vip-2559104837-cqghjg71avvncddhi-1-100195313&c=2&f=1&k=44423951d519b201c3b42e03556e724b&us=62914560&uc=10&v=1", + future, + ) + d, ok := parseCDNExpiry(url) + if !ok { + t.Fatalf("ok = false on real-world URL") + } + if d <= 0 || d > 2*time.Hour { + t.Fatalf("d = %v, want in (0, 2h]", d) + } +} + +func TestParseCDNExpiry_SafetyMarginApplied(t *testing.T) { + // Verify the margin actually shrinks the returned duration. Use a fixed + // offset much larger than the safety margin so jitter doesn't matter. + raw := 1 * time.Hour + future := time.Now().Add(raw).Unix() + url := fmt.Sprintf("https://cdn.example.com/foo?t=%d", future) + d, ok := parseCDNExpiry(url) + if !ok { + t.Fatalf("ok = false") + } + if d >= raw { + t.Fatalf("d = %v, expected strictly less than raw=%v (safety margin not applied)", d, raw) + } + if d < raw-2*cdnExpirySafetyMargin { + t.Fatalf("d = %v, expected ≥ raw-2*margin=%v (margin too aggressive)", d, raw-2*cdnExpirySafetyMargin) + } +} diff --git a/drivers/115_open/meta.go b/drivers/115_open/meta.go index ed908e2e6..0479ac1fd 100644 --- a/drivers/115_open/meta.go +++ b/drivers/115_open/meta.go @@ -11,6 +11,7 @@ type Addition struct { // define other OrderBy string `json:"order_by" type:"select" options:"file_name,file_size,user_utime,file_type"` OrderDirection string `json:"order_direction" type:"select" options:"asc,desc"` + RemoveWay string `json:"remove_way" required:"true" type:"select" options:"trash,delete" default:"trash"` LimitRate float64 `json:"limit_rate" type:"float" default:"1" help:"limit all api request rate ([limit]r/1s)"` PageSize int64 `json:"page_size" type:"number" default:"200" help:"list api per page size of 115open driver"` AccessToken string `json:"access_token" required:"true"` @@ -18,9 +19,10 @@ type Addition struct { } var config = driver.Config{ - Name: "115 Open", - DefaultRoot: "0", - LinkCacheMode: driver.LinkCacheUA, + Name: "115 Open", + DefaultRoot: "0", + ProxyRangeOption: true, + LinkCacheMode: driver.LinkCacheUA, } func init() { diff --git a/drivers/115_open/upload.go b/drivers/115_open/upload.go index d02640e2c..8669cc81b 100644 --- a/drivers/115_open/upload.go +++ b/drivers/115_open/upload.go @@ -3,7 +3,10 @@ package _115_open import ( "context" "encoding/base64" + "encoding/json" + "fmt" "io" + "strings" "time" sdk "github.com/OpenListTeam/115-sdk-go" @@ -14,8 +17,55 @@ import ( "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/aliyun/aliyun-oss-go-sdk/oss" "github.com/avast/retry-go" + log "github.com/sirupsen/logrus" ) +type UploadCallbackResult struct { + State bool `json:"state"` + Code int `json:"code"` + Message string `json:"message"` + Data struct { + PickCode string `json:"pick_code"` + FileName string `json:"file_name"` + FileSize int64 `json:"file_size"` + FileID string `json:"file_id"` + Sha1 string `json:"sha1"` + Cid string `json:"cid"` + } `json:"data"` +} + +func checkUploadCallback(bodyBytes []byte) error { + if len(bodyBytes) == 0 { + return fmt.Errorf("115 upload callback returned empty response") + } + var result UploadCallbackResult + if err := json.Unmarshal(bodyBytes, &result); err != nil { + return fmt.Errorf("115 upload callback response parse error: %w (body: %s)", err, string(bodyBytes)) + } + if !result.State { + return fmt.Errorf("115 upload callback failed: code=%d, message=%s", result.Code, result.Message) + } + return nil +} + +// isTokenExpiredError 检测是否为OSS凭证过期错误 +func isTokenExpiredError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + return strings.Contains(errStr, "SecurityTokenExpired") || + strings.Contains(errStr, "InvalidAccessKeyId") +} + +// isPartAlreadyExistError 检测是否为分片已存在错误(超时后重试时 OSS 返回 409) +func isPartAlreadyExistError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "PartAlreadyExist") +} + func calPartSize(fileSize int64) int64 { var partSize int64 = 20 * utils.MB if fileSize > partSize { @@ -46,36 +96,29 @@ func (d *Open115) singleUpload(ctx context.Context, tempF model.File, tokenResp return err } + var bodyBytes []byte err = bucket.PutObject(initResp.Object, tempF, oss.Callback(base64.StdEncoding.EncodeToString([]byte(initResp.Callback.Value.Callback))), oss.CallbackVar(base64.StdEncoding.EncodeToString([]byte(initResp.Callback.Value.CallbackVar))), + oss.CallbackResult(&bodyBytes), ) - - return err + if err != nil { + return err + } + return checkUploadCallback(bodyBytes) } -// type CallbackResult struct { -// State bool `json:"state"` -// Code int `json:"code"` -// Message string `json:"message"` -// Data struct { -// PickCode string `json:"pick_code"` -// FileName string `json:"file_name"` -// FileSize int64 `json:"file_size"` -// FileID string `json:"file_id"` -// ThumbURL string `json:"thumb_url"` -// Sha1 string `json:"sha1"` -// Aid int `json:"aid"` -// Cid string `json:"cid"` -// } `json:"data"` -// } - func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, up driver.UpdateProgress, tokenResp *sdk.UploadGetTokenResp, initResp *sdk.UploadInitResp) error { - ossClient, err := netutil.NewOSSClient(tokenResp.Endpoint, tokenResp.AccessKeyId, tokenResp.AccessKeySecret, oss.SecurityToken(tokenResp.SecurityToken)) - if err != nil { - return err + // 创建OSS客户端的辅助函数 + createBucket := func(token *sdk.UploadGetTokenResp) (*oss.Bucket, error) { + ossClient, err := netutil.NewOSSClient(token.Endpoint, token.AccessKeyId, token.AccessKeySecret, oss.SecurityToken(token.SecurityToken)) + if err != nil { + return nil, err + } + return ossClient.Bucket(initResp.Bucket) } - bucket, err := ossClient.Bucket(initResp.Bucket) + + bucket, err := createBucket(tokenResp) if err != nil { return err } @@ -112,6 +155,21 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, rd.Seek(0, io.SeekStart) part, err := bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, rd), partSize, int(i)) if err != nil { + if isPartAlreadyExistError(err) { + log.Infof("115 OSS part %d already exists, retrieving from ListUploadedParts", i) + lpr, listErr := bucket.ListUploadedParts(imur) + if listErr != nil { + return fmt.Errorf("part %d already exists but ListUploadedParts failed: %w", i, listErr) + } + for _, p := range lpr.UploadedParts { + if p.PartNumber == int(i) { + parts[i-1] = oss.UploadPart{PartNumber: p.PartNumber, ETag: p.ETag} + log.Infof("115 OSS part %d recovered: ETag=%s", i, p.ETag) + return nil + } + } + return fmt.Errorf("part %d reported as existing but not found in ListUploadedParts", i) + } return err } parts[i-1] = part @@ -120,7 +178,23 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, retry.Context(ctx), retry.Attempts(3), retry.DelayType(retry.BackOffDelay), - retry.Delay(time.Second)) + retry.Delay(time.Second), + retry.OnRetry(func(n uint, err error) { + if isTokenExpiredError(err) { + log.Warnf("115 OSS token expired, refreshing token...") + if newToken, refreshErr := d.client.UploadGetToken(ctx); refreshErr == nil { + tokenResp = newToken + if newBucket, bucketErr := createBucket(tokenResp); bucketErr == nil { + bucket = newBucket + log.Infof("115 OSS token refreshed successfully") + } else { + log.Errorf("Failed to create new bucket with refreshed token: %v", bucketErr) + } + } else { + log.Errorf("Failed to refresh 115 OSS token: %v", refreshErr) + } + } + })) ss.FreeSectionReader(rd) if err != nil { return err @@ -134,17 +208,16 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, up(float64(offset) * 100 / float64(fileSize)) } - // callbackRespBytes := make([]byte, 1024) + var bodyBytes []byte _, err = bucket.CompleteMultipartUpload( imur, parts, oss.Callback(base64.StdEncoding.EncodeToString([]byte(initResp.Callback.Value.Callback))), oss.CallbackVar(base64.StdEncoding.EncodeToString([]byte(initResp.Callback.Value.CallbackVar))), - // oss.CallbackResult(&callbackRespBytes), + oss.CallbackResult(&bodyBytes), ) if err != nil { return err } - - return nil + return checkUploadCallback(bodyBytes) } diff --git a/drivers/115_open/upload_test.go b/drivers/115_open/upload_test.go new file mode 100644 index 000000000..5ad2ff1a0 --- /dev/null +++ b/drivers/115_open/upload_test.go @@ -0,0 +1,29 @@ +package _115_open + +import "testing" + +func TestIsPartAlreadyExistError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"generic error", errStr("some random error"), false}, + {"timeout", errStr("net/http: timeout awaiting response headers"), false}, + {"part already exist", errStr(`oss: service returned error: StatusCode=409, ErrorCode=PartAlreadyExist, ErrorMessage="For sequential multipart upload, you can't overwrite uploaded parts."`), true}, + {"partial match", errStr("PartAlreadyExist"), true}, + {"case sensitive", errStr("partalreadyexist"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isPartAlreadyExistError(tt.err); got != tt.want { + t.Errorf("isPartAlreadyExistError() = %v, want %v", got, tt.want) + } + }) + } +} + +type errStr string + +func (e errStr) Error() string { return string(e) } diff --git a/drivers/115_share/driver.go b/drivers/115_share/driver.go index fe8b7733a..8f8cac01e 100644 --- a/drivers/115_share/driver.go +++ b/drivers/115_share/driver.go @@ -2,6 +2,7 @@ package _115_share import ( "context" + "net/http" "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/driver" @@ -96,8 +97,12 @@ func (d *Pan115Share) Link(ctx context.Context, file model.Obj, args model.LinkA if err != nil { return nil, err } - - return &model.Link{URL: downloadInfo.URL.URL}, nil + header := http.Header{} + header.Set("User-Agent", ua) + return &model.Link{ + URL: downloadInfo.URL.URL, + Header: header, + }, nil } func (d *Pan115Share) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { diff --git a/drivers/123/util.go b/drivers/123/util.go index 5cce49b65..d53a3d9db 100644 --- a/drivers/123/util.go +++ b/drivers/123/util.go @@ -181,7 +181,7 @@ func (d *Pan123) login() error { return err } if utils.Json.Get(res.Body(), "code").ToInt() != 200 { - err = fmt.Errorf(utils.Json.Get(res.Body(), "message").ToString()) + err = fmt.Errorf("%s", utils.Json.Get(res.Body(), "message").ToString()) } else { d.AccessToken = utils.Json.Get(res.Body(), "data", "token").ToString() } diff --git a/drivers/123_open/driver.go b/drivers/123_open/driver.go index 78ff272b9..9adb1aed0 100644 --- a/drivers/123_open/driver.go +++ b/drivers/123_open/driver.go @@ -175,7 +175,7 @@ func (d *Open123) Remove(ctx context.Context, obj model.Obj) error { } func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { - // 1. 创建文件 + // 1. 准备参数 // parentFileID 父目录id,上传到根目录时填写 0 parentFileId, err := strconv.ParseInt(dstDir.GetID(), 10, 64) if err != nil { @@ -197,14 +197,49 @@ func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStre } } + // etag 文件md5 etag := file.GetHash().GetHash(utils.MD5) - if len(etag) < utils.MD5.Width { + + // 检查是否是可重复读取的流 + _, isSeekable := file.(*stream.SeekableStream) + + // 如果有预计算的 hash,先尝试秒传 + if len(etag) >= utils.MD5.Width { + createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) + if err != nil { + return nil, err + } + if createResp.Data.Reuse && createResp.Data.FileID != 0 { + return File{ + FileName: file.GetName(), + Size: file.GetSize(), + FileId: createResp.Data.FileID, + Type: 2, + Etag: etag, + }, nil + } + // 秒传失败,继续后续流程 + } + + if isSeekable { + // 可重复读取的流,使用 RangeRead 计算 hash,不缓存 + if len(etag) < utils.MD5.Width { + etag, err = stream.StreamHashFile(file, utils.MD5, 100, &up) + if err != nil { + return nil, err + } + } + } else { + // 不可重复读取的流(如 HTTP body) + // 秒传失败或没有 hash,缓存整个文件并计算 MD5 _, etag, err = stream.CacheFullAndHash(file, &up, utils.MD5) if err != nil { return nil, err } } + + // 2. 创建上传任务(或再次尝试秒传) createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) if err != nil { return nil, err @@ -223,13 +258,16 @@ func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStre } } - // 2. 上传分片 - err = d.Upload(ctx, file, createResp, up) + // 3. 上传分片 + uploadProgress := func(p float64) { + up(40 + p*0.6) + } + err = d.Upload(ctx, file, createResp, uploadProgress) if err != nil { return nil, err } - // 3. 上传完毕 + // 4. 合并分片/完成上传 for range 60 { uploadCompleteResp, err := d.complete(createResp.Data.PreuploadID) // 返回错误代码未知,如:20103,文档也没有具体说 diff --git a/drivers/139/util.go b/drivers/139/util.go index 85c798cc7..c026de52a 100644 --- a/drivers/139/util.go +++ b/drivers/139/util.go @@ -1180,7 +1180,6 @@ func (d *Yun139) step3_third_party_login(dycpwd string) (string, error) { "x-DeviceInfo": "4|127.0.0.1|5|1.2.6|Xiaomi|23116PN5BC||02-00-00-00-00-00|android 15|1440x3200|android|||", "Content-Type": "text/plain;charset=UTF-8", "Host": "user-njs.yun.139.com", - "Connection": "Keep-Alive", "Accept-Encoding": "gzip", "User-Agent": "okhttp/3.12.2", } diff --git a/drivers/189/torrent.go b/drivers/189/torrent.go new file mode 100644 index 000000000..4e41bf490 --- /dev/null +++ b/drivers/189/torrent.go @@ -0,0 +1,149 @@ +package _189 + +import ( + "context" + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "strings" + + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/torrent" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" +) + +// GenerateTorrent 根据上传过程中收集的哈希信息生成包含 CAS 扩展的 torrent 文件 +func GenerateTorrent(fileName string, fileSize int64, fileMD5 string, sliceMD5s []string, sliceSize int64, pieceHashes []byte) ([]byte, error) { + // 计算 sliceMD5 + sliceMD5 := fileMD5 + if len(sliceMD5s) > 1 { + joined := strings.Join(sliceMD5s, "\n") + sliceMD5 = strings.ToUpper(torrent.GetMD5Str(joined)) + } + + t := torrent.NewTorrent(fileName, fileSize, fileMD5) + t.Info.PieceLength = sliceSize + t.SetPieces(pieceHashes) + t.SetCASInfo(&torrent.CASInfo{ + FileMD5: fileMD5, + SliceMD5: sliceMD5, + SliceMD5s: sliceMD5s, + SliceSize: sliceSize, + Cloud: "189", + }) + + return t.Encode() +} + +// RapidUploadFromTorrent 从 torrent 文件中提取 CAS 信息进行秒传 +func (d *Cloud189) RapidUploadFromTorrent(ctx context.Context, dstDir model.Obj, torrentData []byte) error { + // 解析 torrent + t, err := torrent.Decode(torrentData) + if err != nil { + return fmt.Errorf("解析 torrent 失败: %w", err) + } + + // 检查是否包含 CAS 扩展信息 + if !t.HasCASInfo() { + return fmt.Errorf("torrent 不包含 CAS 扩展信息,无法秒传") + } + + cas := t.CAS + fileName := t.Info.Name + fileSize := t.GetTotalSize() + + // 获取 sessionKey + sessionKey, err := d.getSessionKey() + if err != nil { + return err + } + d.sessionKey = sessionKey + + // 初始化上传 + res, err := d.uploadRequest("/person/initMultiUpload", map[string]string{ + "parentFolderId": dstDir.GetID(), + "fileName": encode(fileName), + "fileSize": fmt.Sprint(fileSize), + "sliceSize": fmt.Sprint(cas.SliceSize), + "lazyCheck": "1", + }, nil) + if err != nil { + return fmt.Errorf("初始化上传失败: %w", err) + } + + uploadFileId := utils.Json.Get(res, "data", "uploadFileId").ToString() + + // 提交上传(使用 CAS 信息秒传) + _, err = d.uploadRequest("/person/commitMultiUploadFile", map[string]string{ + "uploadFileId": uploadFileId, + "fileMd5": cas.FileMD5, + "sliceMd5": cas.SliceMD5, + "lazyCheck": "1", + "opertype": "3", + }, nil) + if err != nil { + return fmt.Errorf("秒传提交失败: %w", err) + } + + return nil +} + +// ComputeTorrentFromReader 从 io.Reader 计算并生成 torrent 文件 +func ComputeTorrentFromReader(reader io.Reader, fileName string, fileSize int64, sliceSize int64) ([]byte, error) { + if sliceSize <= 0 { + sliceSize = torrent.DefaultPieceSize + } + + hw := torrent.NewHashWriter(sliceSize, sliceSize) + + buf := make([]byte, 32*1024) + for { + n, err := reader.Read(buf) + if n > 0 { + hw.Write(buf[:n]) + } + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + } + hw.Finish() + + fileMD5 := hw.GetFileMD5() + sliceMD5s := hw.GetSliceMD5s() + pieceHashes := hw.GetPieceHashes() + + return GenerateTorrent(fileName, fileSize, fileMD5, sliceMD5s, sliceSize, pieceHashes) +} + +// ComputePieceSHA1 计算单个分片的 SHA-1 哈希 +func ComputePieceSHA1(data []byte) []byte { + h := sha1.Sum(data) + return h[:] +} + +// ExtractCASFromTorrent 从 torrent 数据中提取 CAS 信息 +func ExtractCASFromTorrent(torrentData []byte) (*torrent.CASInfo, string, int64, error) { + t, err := torrent.Decode(torrentData) + if err != nil { + return nil, "", 0, fmt.Errorf("解析 torrent 失败: %w", err) + } + + if !t.HasCASInfo() { + return nil, "", 0, fmt.Errorf("torrent 不包含 CAS 扩展信息") + } + + return t.CAS, t.Info.Name, t.GetTotalSize(), nil +} + +// GetInfoHashHex 获取 torrent 的 info_hash(十六进制字符串) +func GetInfoHashHex(torrentData []byte) (string, error) { + t, err := torrent.Decode(torrentData) + if err != nil { + return "", err + } + return hex.EncodeToString(t.InfoHash), nil +} diff --git a/drivers/189/util.go b/drivers/189/util.go index bb9a6adb4..5fad92c1d 100644 --- a/drivers/189/util.go +++ b/drivers/189/util.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/md5" + sha1Pkg "crypto/sha1" "encoding/base64" "encoding/hex" "errors" @@ -18,6 +19,8 @@ import ( "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" myrand "github.com/OpenListTeam/OpenList/v4/pkg/utils/random" "github.com/go-resty/resty/v2" @@ -235,7 +238,7 @@ func (d *Cloud189) oldUpload(dstDir model.Obj, file model.FileStreamer) error { if utils.Json.Get(res.Body(), "MD5").ToString() != "" { return nil } - log.Debugf(res.String()) + log.Debugf("%s", res.String()) return errors.New(res.String()) } @@ -311,48 +314,107 @@ func (d *Cloud189) newUpload(ctx context.Context, dstDir model.Obj, file model.F } d.sessionKey = sessionKey const DEFAULT int64 = 10485760 - count := int64(math.Ceil(float64(file.GetSize()) / float64(DEFAULT))) + fileSize := file.GetSize() + count := int64(math.Ceil(float64(fileSize) / float64(DEFAULT))) - res, err := d.uploadRequest("/person/initMultiUpload", map[string]string{ + // 先计算文件完整MD5和分片MD5,用于秒传判断 + fileMd5Hex := file.GetHash().GetHash(utils.MD5) + sliceMd5Hex := "" + md5s := make([]string, 0) + + if len(fileMd5Hex) < utils.MD5.Width { + // 没有MD5,先缓存流并同时计算文件MD5和分片MD5 + fileMd5Hash := md5.New() + sliceMd5Hash := md5.New() + var finish int64 + cache, err := file.CacheFullAndWriter(nil, io.MultiWriter(fileMd5Hash, &sliceHashWriter{ + hash: sliceMd5Hash, + md5s: &md5s, + sliceSize: DEFAULT, + finish: &finish, + fileSize: fileSize, + up: up, + ctx: ctx, + })) + if err != nil { + return err + } + // 处理最后一个分片的MD5 + if finish%DEFAULT != 0 || finish == 0 { + md5s = append(md5s, strings.ToUpper(hex.EncodeToString(sliceMd5Hash.Sum(nil)))) + } + fileMd5Hex = hex.EncodeToString(fileMd5Hash.Sum(nil)) + + // seek回起始位置,供后续上传使用 + if _, err := cache.Seek(0, io.SeekStart); err != nil { + return err + } + } + + // 计算sliceMd5 + if fileSize > DEFAULT && len(md5s) > 0 { + sliceMd5Hex = utils.GetMD5EncodeStr(strings.Join(md5s, "\n")) + } else { + sliceMd5Hex = fileMd5Hex + } + + // 带fileMd5调用initMultiUpload,支持秒传 + initParams := map[string]string{ "parentFolderId": dstDir.GetID(), "fileName": encode(file.GetName()), - "fileSize": strconv.FormatInt(file.GetSize(), 10), + "fileSize": strconv.FormatInt(fileSize, 10), "sliceSize": strconv.FormatInt(DEFAULT, 10), - "lazyCheck": "1", - }, nil) + "fileMd5": fileMd5Hex, + "sliceMd5": sliceMd5Hex, + } + + res, err := d.uploadRequest("/person/initMultiUpload", initParams, nil) if err != nil { return err } uploadFileId := jsoniter.Get(res, "data", "uploadFileId").ToString() - //_, err = d.uploadRequest("/person/getUploadedPartsInfo", map[string]string{ - // "uploadFileId": uploadFileId, - //}, nil) + fileDataExists := jsoniter.Get(res, "data", "fileDataExists").ToInt() + + // 秒传成功,直接提交 + if fileDataExists == 1 { + _, err = d.uploadRequest("/person/commitMultiUploadFile", map[string]string{ + "uploadFileId": uploadFileId, + "fileMd5": fileMd5Hex, + "sliceMd5": sliceMd5Hex, + "lazyCheck": "1", + "opertype": "3", + }, nil) + return err + } + + // 非秒传,需要上传分片 var finish int64 = 0 var i int64 var byteSize int64 - md5s := make([]string, 0) - md5Sum := md5.New() + + // 额外计算 SHA-1 piece hash 用于生成 torrent + pieceSHA1Hashes := make([]byte, 0, int(count)*20) + for i = 1; i <= count; i++ { if utils.IsCanceled(ctx) { return ctx.Err() } - byteSize = file.GetSize() - finish + byteSize = fileSize - finish if DEFAULT < byteSize { byteSize = DEFAULT } - // log.Debugf("%d,%d", byteSize, finish) byteData := make([]byte, byteSize) n, err := io.ReadFull(file, byteData) - // log.Debug(err, n) if err != nil { return err } finish += int64(n) md5Bytes := getMd5(byteData) - md5Hex := hex.EncodeToString(md5Bytes) md5Base64 := base64.StdEncoding.EncodeToString(md5Bytes) - md5s = append(md5s, strings.ToUpper(md5Hex)) - md5Sum.Write(byteData) + + // 计算 SHA-1 piece hash + sha1Hash := sha1Pkg.Sum(byteData) + pieceSHA1Hashes = append(pieceSHA1Hashes, sha1Hash[:]...) var resp UploadUrlsResp res, err = d.uploadRequest("/person/getMultiUploadUrls", map[string]string{ "partInfo": fmt.Sprintf("%s-%s", strconv.FormatInt(i, 10), md5Base64), @@ -379,21 +441,58 @@ func (d *Cloud189) newUpload(ctx context.Context, dstDir model.Obj, file model.F } log.Debugf("%+v %+v", r, r.Request.Header) _ = r.Body.Close() - up(float64(i) * 100 / float64(count)) - } - fileMd5 := hex.EncodeToString(md5Sum.Sum(nil)) - sliceMd5 := fileMd5 - if file.GetSize() > DEFAULT { - sliceMd5 = utils.GetMD5EncodeStr(strings.Join(md5s, "\n")) + up(50 + float64(i)*50/float64(count)) } res, err = d.uploadRequest("/person/commitMultiUploadFile", map[string]string{ "uploadFileId": uploadFileId, - "fileMd5": fileMd5, - "sliceMd5": sliceMd5, + "fileMd5": fileMd5Hex, + "sliceMd5": sliceMd5Hex, "lazyCheck": "1", "opertype": "3", }, nil) - return err + if err != nil { + return err + } + + // 生成 torrent 文件(异步,不影响上传结果) + capturedDstDir := dstDir + capturedFileName := file.GetName() + capturedFileSize := fileSize + capturedFileMd5Hex := fileMd5Hex + capturedMd5s := md5s + go func() { + fileMD5Upper := strings.ToUpper(capturedFileMd5Hex) + torrentData, err := GenerateTorrent(capturedFileName, capturedFileSize, fileMD5Upper, capturedMd5s, DEFAULT, pieceSHA1Hashes) + if err != nil { + log.Warnf("生成 torrent 失败: %v", err) + return + } + infoHash, _ := GetInfoHashHex(torrentData) + torrentName := capturedFileName + ".cas.torrent" + log.Infof("已生成 torrent: %s (info_hash: %s, size: %d bytes)", + torrentName, infoHash, len(torrentData)) + + // 将 torrent 文件上传到同一目录 + torrentFileStream := &stream.FileStream{ + Ctx: context.Background(), + Obj: &model.Object{ + Name: torrentName, + Size: int64(len(torrentData)), + IsFolder: false, + }, + Reader: bytes.NewReader(torrentData), + Mimetype: "application/x-bittorrent", + } + uploadErr := d.oldUpload(capturedDstDir, torrentFileStream) + if uploadErr != nil { + log.Warnf("上传 torrent 文件失败: %v", uploadErr) + } else { + log.Infof("torrent 文件已上传: %s", torrentName) + op.Cache.DeleteDirectory(d, capturedDstDir.GetPath()) + } + }() + + return nil } func (d *Cloud189) getCapacityInfo(ctx context.Context) (*CapacityResp, error) { @@ -406,3 +505,52 @@ func (d *Cloud189) getCapacityInfo(ctx context.Context) (*CapacityResp, error) { } return &resp, nil } + +// sliceHashWriter 在写入过程中按分片大小自动切分并计算每个分片的MD5, +// 同时支持进度回调和取消检查。 +type sliceHashWriter struct { + hash io.Writer // 当前分片的MD5 hash + md5s *[]string // 收集每个分片的MD5十六进制字符串 + sliceSize int64 // 分片大小 + finish *int64 // 已写入的总字节数 + fileSize int64 // 文件总大小 + up driver.UpdateProgress + ctx context.Context +} + +func (w *sliceHashWriter) Write(p []byte) (int, error) { + if utils.IsCanceled(w.ctx) { + return 0, w.ctx.Err() + } + total := len(p) + written := 0 + for written < total { + // 当前分片还能写入的字节数 + sliceRemain := w.sliceSize - (*w.finish % w.sliceSize) + toWrite := int64(total - written) + if toWrite > sliceRemain { + toWrite = sliceRemain + } + n, err := w.hash.Write(p[written : written+int(toWrite)]) + if err != nil { + return written, err + } + written += n + *w.finish += int64(n) + + // 当前分片写满,记录MD5并重置 + if *w.finish%w.sliceSize == 0 { + if h, ok := w.hash.(interface{ Sum([]byte) []byte }); ok { + *w.md5s = append(*w.md5s, strings.ToUpper(hex.EncodeToString(h.Sum(nil)))) + } + if resetter, ok := w.hash.(interface{ Reset() }); ok { + resetter.Reset() + } + } + } + // 报告进度(缓存阶段占50%) + if w.fileSize > 0 && w.up != nil { + w.up(float64(*w.finish) / float64(w.fileSize) * 50) + } + return total, nil +} diff --git a/drivers/189pc/driver.go b/drivers/189pc/driver.go index af719401d..82aa1c1af 100644 --- a/drivers/189pc/driver.go +++ b/drivers/189pc/driver.go @@ -87,7 +87,12 @@ func (y *Cloud189PC) Init(ctx context.Context) (err error) { } // 先尝试用Token刷新,之后尝试登陆 - if y.Addition.RefreshToken != "" { + if y.Addition.AccessToken != "" { + y.tokenInfo = &AppSessionResp{AccessToken: y.Addition.AccessToken, RefreshToken: y.Addition.RefreshToken} + if err = y.refreshSession(); err != nil { + return err + } + } else if y.Addition.RefreshToken != "" { y.tokenInfo = &AppSessionResp{RefreshToken: y.Addition.RefreshToken} if err = y.refreshToken(); err != nil { return err @@ -288,7 +293,7 @@ func (y *Cloud189PC) Rename(ctx context.Context, srcObj model.Obj, newName strin req.SetContext(ctx).SetQueryParams(queryParam) }, nil, resp, isFamily) if err != nil { - if resp.ResCode == "FileAlreadyExists" { + if code, ok := resp.ResCode.(string); ok && code == "FileAlreadyExists" { return nil, errs.ObjectAlreadyExists } return nil, err @@ -338,6 +343,7 @@ func (y *Cloud189PC) Put(ctx context.Context, dstDir model.Obj, stream model.Fil // 响应时间长,按需启用 if y.Addition.RapidUpload && !stream.IsForceStreamUpload() { + // 尝试妙传 if newObj, err := y.RapidUpload(ctx, dstDir, stream, isFamily, overwrite); err == nil { return newObj, nil } @@ -346,10 +352,11 @@ func (y *Cloud189PC) Put(ctx context.Context, dstDir model.Obj, stream model.Fil uploadMethod := y.UploadMethod if stream.IsForceStreamUpload() { uploadMethod = "stream" - } - - // 旧版上传家庭云也有限制 - if uploadMethod == "old" { + } else if y.Addition.RapidUpload && stream.GetFile() != nil { + // 文件流支持随机读取,走FastUpload计算MD5并尝试秒传 + uploadMethod = "rapid" + } else if uploadMethod == "old" { + // 旧版上传家庭云也有限制 return y.OldUpload(ctx, dstDir, stream, up, isFamily, overwrite) } diff --git a/drivers/189pc/meta.go b/drivers/189pc/meta.go index 670b99116..3a5e02979 100644 --- a/drivers/189pc/meta.go +++ b/drivers/189pc/meta.go @@ -10,17 +10,19 @@ type Addition struct { Username string `json:"username" required:"true"` Password string `json:"password" required:"true"` VCode string `json:"validate_code"` + AccessToken string `json:"access_token" required:"false"` RefreshToken string `json:"refresh_token" help:"To switch accounts, please clear this field"` driver.RootID - OrderBy string `json:"order_by" type:"select" options:"filename,filesize,lastOpTime" default:"filename"` - OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` - Type string `json:"type" type:"select" options:"personal,family" default:"personal"` - FamilyID string `json:"family_id"` - UploadMethod string `json:"upload_method" type:"select" options:"stream,rapid,old" default:"stream"` - UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"` - FamilyTransfer bool `json:"family_transfer"` - RapidUpload bool `json:"rapid_upload"` - NoUseOcr bool `json:"no_use_ocr"` + OrderBy string `json:"order_by" type:"select" options:"filename,filesize,lastOpTime" default:"filename"` + OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` + Type string `json:"type" type:"select" options:"personal,family" default:"personal"` + FamilyID string `json:"family_id"` + UploadMethod string `json:"upload_method" type:"select" options:"stream,rapid,old" default:"stream"` + UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"` + FamilyTransfer bool `json:"family_transfer"` + RapidUpload bool `json:"rapid_upload"` + NoUseOcr bool `json:"no_use_ocr"` + GenerateTorrent bool `json:"generate_torrent" help:"Generate torrent file with CAS extension after upload"` } var config = driver.Config{ diff --git a/drivers/189pc/torrent.go b/drivers/189pc/torrent.go new file mode 100644 index 000000000..3068a0f7b --- /dev/null +++ b/drivers/189pc/torrent.go @@ -0,0 +1,296 @@ +package _189pc + +import ( + "context" + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "net/url" + "strings" + + "github.com/go-resty/resty/v2" + + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/torrent" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" +) + +// GenerateTorrent 根据上传过程中收集的哈希信息生成包含 CAS 扩展的 torrent 文件 +// fileMD5: 整文件 MD5(大写十六进制) +// sliceMD5s: 每个分片的 MD5 列表(大写十六进制) +// sliceSize: 分片大小 +// pieceHashes: SHA-1 piece hashes 拼接(每 20 字节一个) +// fileName: 文件名 +// fileSize: 文件大小 +func GenerateTorrent(fileName string, fileSize int64, fileMD5 string, sliceMD5s []string, sliceSize int64, pieceHashes []byte) ([]byte, error) { + // 计算 sliceMD5 + sliceMD5 := fileMD5 + if len(sliceMD5s) > 1 { + joined := strings.Join(sliceMD5s, "\n") + sliceMD5 = strings.ToUpper(torrent.GetMD5Str(joined)) + } + + t := torrent.NewTorrent(fileName, fileSize, fileMD5) + t.Info.PieceLength = sliceSize + t.SetPieces(pieceHashes) + t.SetCASInfo(&torrent.CASInfo{ + FileMD5: fileMD5, + SliceMD5: sliceMD5, + SliceMD5s: sliceMD5s, + SliceSize: sliceSize, + Cloud: "189", + }) + + return t.Encode() +} + +// RapidUploadFromTorrent 从 torrent 文件中提取 CAS 信息进行秒传 +// 返回值:上传成功的文件对象、错误 +func (y *Cloud189PC) RapidUploadFromTorrent(ctx context.Context, dstDir model.Obj, torrentData []byte, overwrite bool) (model.Obj, error) { + isFamily := y.isFamily() + + // 解析 torrent + t, err := torrent.Decode(torrentData) + if err != nil { + return nil, fmt.Errorf("解析 torrent 失败: %w", err) + } + + // 检查是否包含 CAS 扩展信息 + if !t.HasCASInfo() { + return nil, fmt.Errorf("torrent 不包含 CAS 扩展信息,无法秒传") + } + + cas := t.CAS + fileName := t.Info.Name + fileSize := t.GetTotalSize() + + // 统一 MD5 为大写(与正常上传保持一致,天翼云盘要求大写) + fileMD5Upper := strings.ToUpper(cas.FileMD5) + + // 优先使用 torrent 中嵌入的分片大小,与生成时保持一致 + sliceSize := cas.SliceSize + if sliceSize <= 0 { + sliceSize = partSize(fileSize) + } + + // 计算 sliceMd5(与上传时一致的算法) + // 优先使用 torrent 中已有的 SliceMD5;仅当有多分片列表时才重新计算 + sliceMd5Hex := strings.ToUpper(cas.SliceMD5) + if sliceMd5Hex == "" { + sliceMd5Hex = fileMD5Upper + } + if len(cas.SliceMD5s) > 1 { + // 分片 MD5 也需要统一大写后再拼接计算 + upperSliceMD5s := make([]string, len(cas.SliceMD5s)) + for i, s := range cas.SliceMD5s { + upperSliceMD5s[i] = strings.ToUpper(s) + } + sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(upperSliceMD5s, "\n"))) + } + + + // 使用与 Web 端一致的三步秒传流程 + fullUrl := "https://upload.cloud.189.cn" + if isFamily { + fullUrl += "/family" + } else { + fullUrl += "/person" + } + + // Step 1: initMultiUpload(不传 fileMd5/sliceMd5,只传 lazyCheck) + initParams := Params{ + "parentFolderId": dstDir.GetID(), + "fileName": url.QueryEscape(fileName), + "fileSize": fmt.Sprint(fileSize), + "sliceSize": fmt.Sprint(sliceSize), + "lazyCheck": "1", + } + if isFamily { + initParams.Set("familyId", y.FamilyID) + } + + + var uploadInfo InitMultiUploadResp + _, err = y.request(fullUrl+"/initMultiUpload", "GET", func(req *resty.Request) { + req.SetContext(ctx) + }, initParams, &uploadInfo, isFamily) + if err != nil { + return nil, fmt.Errorf("initMultiUpload 失败: %w", err) + } + + + uploadFileId := uploadInfo.Data.UploadFileID + + // Step 2: checkTransSecond(用 fileMd5 + sliceMd5 + uploadFileId 检查秒传) + checkParams := Params{ + "fileMd5": fileMD5Upper, + "sliceMd5": sliceMd5Hex, + "uploadFileId": uploadFileId, + } + + + var checkResp struct { + Data struct { + FileDataExists int `json:"fileDataExists"` + } `json:"data"` + } + _, err = y.request(fullUrl+"/checkTransSecond", "GET", func(req *resty.Request) { + req.SetContext(ctx) + }, checkParams, &checkResp, isFamily) + if err != nil { + utils.Log.Errorf("[RapidUpload] checkTransSecond 失败: uploadFileId=%s, err=%v", uploadFileId, err) + return nil, fmt.Errorf("秒传检查失败: %w", err) + } + + + if checkResp.Data.FileDataExists != 1 { + return nil, fmt.Errorf("秒传失败:云端不存在该文件(fileMD5=%s, sliceMD5=%s, size=%d)", fileMD5Upper, sliceMd5Hex, fileSize) + } + + // Step 3: commitMultiUploadFile(传 fileMd5 + sliceMd5) + + var resp CommitMultiUploadFileResp + commitParams := Params{ + "uploadFileId": uploadFileId, + "fileMd5": fileMD5Upper, + "sliceMd5": sliceMd5Hex, + "lazyCheck": "1", + "opertype": IF(overwrite, "3", "1"), + } + + _, err = y.request(fullUrl+"/commitMultiUploadFile", "GET", func(req *resty.Request) { + req.SetContext(ctx) + }, commitParams, &resp, isFamily) + if err != nil { + utils.Log.Errorf("[RapidUpload] commitMultiUploadFile 失败: uploadFileId=%s, err=%v", uploadFileId, err) + return nil, fmt.Errorf("提交上传失败: %w", err) + } + + return resp.toFile(), nil +} + +// ComputeTorrentFromReader 从 io.Reader 计算并生成 torrent 文件 +// 适用于:已有文件需要生成 torrent 的场景(如下载完成后生成) +func ComputeTorrentFromReader(reader io.Reader, fileName string, fileSize int64, sliceSize int64) ([]byte, error) { + if sliceSize <= 0 { + sliceSize = torrent.DefaultPieceSize + } + + hw := torrent.NewHashWriter(sliceSize, sliceSize) + + buf := make([]byte, 32*1024) + for { + n, err := reader.Read(buf) + if n > 0 { + hw.Write(buf[:n]) + } + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + } + hw.Finish() + + fileMD5 := hw.GetFileMD5() + sliceMD5s := hw.GetSliceMD5s() + pieceHashes := hw.GetPieceHashes() + + return GenerateTorrent(fileName, fileSize, fileMD5, sliceMD5s, sliceSize, pieceHashes) +} + +// ComputePieceSHA1 计算单个分片的 SHA-1 哈希 +func ComputePieceSHA1(data []byte) []byte { + h := sha1.Sum(data) + return h[:] +} + +// ExtractCASFromTorrent 从 torrent 数据中提取 CAS 信息 +// 返回:CAS 信息、文件名、文件大小、错误 +func ExtractCASFromTorrent(torrentData []byte) (*torrent.CASInfo, string, int64, error) { + t, err := torrent.Decode(torrentData) + if err != nil { + return nil, "", 0, fmt.Errorf("解析 torrent 失败: %w", err) + } + + if !t.HasCASInfo() { + return nil, "", 0, fmt.Errorf("torrent 不包含 CAS 扩展信息") + } + + return t.CAS, t.Info.Name, t.GetTotalSize(), nil +} + +// InjectCASIntoTorrent 向已有的 torrent 文件注入 CAS 扩展信息 +// 用于:下载完成后,计算了 MD5 信息,写回到 torrent 中 +func InjectCASIntoTorrent(torrentData []byte, fileMD5 string, sliceMD5s []string, sliceSize int64) ([]byte, error) { + t, err := torrent.Decode(torrentData) + if err != nil { + return nil, fmt.Errorf("解析 torrent 失败: %w", err) + } + + // 计算 sliceMD5 + sliceMD5 := fileMD5 + if len(sliceMD5s) > 1 { + joined := strings.Join(sliceMD5s, "\n") + sliceMD5 = strings.ToUpper(torrent.GetMD5Str(joined)) + } + + // 注入 CAS 信息 + t.SetCASInfo(&torrent.CASInfo{ + FileMD5: fileMD5, + SliceMD5: sliceMD5, + SliceMD5s: sliceMD5s, + SliceSize: sliceSize, + Cloud: "189", + }) + + // 同时更新 info 中的 md5sum 字段 + if t.Info.MD5Sum == "" { + t.Info.MD5Sum = fileMD5 + } + + return t.Encode() +} + +// GetInfoHashHex 获取 torrent 的 info_hash(十六进制字符串) +func GetInfoHashHex(torrentData []byte) (string, error) { + t, err := torrent.Decode(torrentData) + if err != nil { + return "", err + } + return hex.EncodeToString(t.InfoHash), nil +} + +// ComputeSliceMD5sFromReader 从 reader 中计算每个 10MB 分片的 MD5 +// 返回:整文件 MD5、分片 MD5 列表 +func ComputeSliceMD5sFromReader(reader io.Reader, sliceSize int64) (string, []string, error) { + if sliceSize <= 0 { + sliceSize = torrent.DefaultPieceSize + } + + fileMD5Hash := utils.MD5.NewFunc() + sliceMD5s := make([]string, 0) + + buf := make([]byte, sliceSize) + for { + n, err := io.ReadFull(reader, buf) + if n > 0 { + chunk := buf[:n] + fileMD5Hash.Write(chunk) + // 计算该分片的 MD5 + sliceMD5 := strings.ToUpper(utils.HashData(utils.MD5, chunk)) + sliceMD5s = append(sliceMD5s, sliceMD5) + } + if err == io.EOF || err == io.ErrUnexpectedEOF { + break + } + if err != nil { + return "", nil, err + } + } + + fileMD5Hex := strings.ToUpper(hex.EncodeToString(fileMD5Hash.Sum(nil))) + return fileMD5Hex, sliceMD5s, nil +} diff --git a/drivers/189pc/types.go b/drivers/189pc/types.go index eed447e25..c05e3867f 100644 --- a/drivers/189pc/types.go +++ b/drivers/189pc/types.go @@ -441,7 +441,7 @@ type RenameResp struct { ParentID int64 `json:"parentId"` Rev string `json:"rev"` Size int64 `json:"size"` - ResCode string `json:"res_code"` + ResCode any `json:"res_code"` // int or string } func (r *RenameResp) toFile(f *Cloud189File) *Cloud189File { diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index 08ee658ca..86cf834e4 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -3,6 +3,7 @@ package _189pc import ( "bytes" "context" + sha1Pkg "crypto/sha1" "encoding/base64" "encoding/hex" "encoding/xml" @@ -353,9 +354,10 @@ func (y *Cloud189PC) loginByPassword() (err error) { return &erron } if tokenInfo.ResCode != 0 { - err = fmt.Errorf(tokenInfo.ResMessage) + err = fmt.Errorf("%s", tokenInfo.ResMessage) return err } + y.Addition.AccessToken = tokenInfo.AccessToken y.Addition.RefreshToken = tokenInfo.RefreshToken y.tokenInfo = &tokenInfo op.MustSaveDriverStorage(y) @@ -412,8 +414,9 @@ func (y *Cloud189PC) loginByQRCode() error { return err } if tokenInfo.ResCode != 0 { - return fmt.Errorf(tokenInfo.ResMessage) + return fmt.Errorf("%s", tokenInfo.ResMessage) } + y.Addition.AccessToken = tokenInfo.AccessToken y.Addition.RefreshToken = tokenInfo.RefreshToken y.tokenInfo = &tokenInfo op.MustSaveDriverStorage(y) @@ -661,6 +664,7 @@ func (y *Cloud189PC) refreshTokenWithRetry(retryCount int) (err error) { return y.login() } + y.Addition.AccessToken = tokenInfo.AccessToken y.Addition.RefreshToken = tokenInfo.RefreshToken y.tokenInfo = &tokenInfo op.MustSaveDriverStorage(y) @@ -739,6 +743,10 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo silceMd5 := utils.MD5.NewFunc() var writers io.Writer = silceMd5 + // 如果启用了 torrent 生成,额外计算 SHA-1 piece hash + generateTorrent := y.Addition.GenerateTorrent + pieceSHA1Hashes := make([]byte, 0, count*20) + fileMd5Hex := file.GetHash().GetHash(utils.MD5) var fileMd5 hash.Hash if len(fileMd5Hex) != utils.MD5.Width { @@ -763,7 +771,18 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo return err } silceMd5.Reset() - w, err := utils.CopyWithBuffer(writers, reader) + + // 如果需要生成 torrent,同时计算 SHA-1 + var sha1Writer hash.Hash + var multiWriter io.Writer + if generateTorrent { + sha1Writer = sha1Pkg.New() + multiWriter = io.MultiWriter(writers, sha1Writer) + } else { + multiWriter = writers + } + + w, err := utils.CopyWithBuffer(multiWriter, reader) if w != partSize { return fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", partSize, w, err) } @@ -771,6 +790,11 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo md5Bytes := silceMd5.Sum(nil) silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Bytes))) partInfo = fmt.Sprintf("%d-%s", i, base64.StdEncoding.EncodeToString(md5Bytes)) + + // 收集 SHA-1 piece hash + if generateTorrent && sha1Writer != nil { + pieceSHA1Hashes = append(pieceSHA1Hashes, sha1Writer.Sum(nil)...) + } return nil }, Do: func(ctx context.Context) (err error) { @@ -824,6 +848,45 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo if err != nil { return nil, err } + + // 生成 torrent 文件(异步,不影响上传结果) + if generateTorrent && len(pieceSHA1Hashes) > 0 { + // 捕获必要的变量 + capturedDstDir := dstDir + capturedIsFamily := isFamily + capturedFileName := file.GetName() + go func() { + torrentData, err := GenerateTorrent(capturedFileName, fileSize, fileMd5Hex, silceMd5Hexs, sliceSize, pieceSHA1Hashes) + if err != nil { + utils.Log.Warnf("生成 torrent 失败: %v", err) + return + } + infoHash, _ := GetInfoHashHex(torrentData) + torrentName := capturedFileName + ".cas.torrent" + utils.Log.Infof("已生成 torrent: %s (info_hash: %s, size: %d bytes)", + torrentName, infoHash, len(torrentData)) + + // 将 torrent 文件上传到同一目录(使用 FastUpload,因为 torrent 文件很小) + torrentFileStream := &stream.FileStream{ + Ctx: context.Background(), + Obj: &model.Object{ + Name: torrentName, + Size: int64(len(torrentData)), + IsFolder: false, + }, + Reader: bytes.NewReader(torrentData), + Mimetype: "application/x-bittorrent", + } + _, uploadErr := y.FastUpload(context.Background(), capturedDstDir, torrentFileStream, func(p float64) {}, capturedIsFamily, false) + if uploadErr != nil { + utils.Log.Warnf("上传 torrent 文件失败: %v", uploadErr) + } else { + utils.Log.Infof("torrent 文件已上传: %s", torrentName) + op.Cache.DeleteDirectory(y, capturedDstDir.GetPath()) + } + }() + } + return resp.toFile(), nil } diff --git a/drivers/alias/util.go b/drivers/alias/util.go index 8e5eb8a84..b37854394 100644 --- a/drivers/alias/util.go +++ b/drivers/alias/util.go @@ -40,7 +40,7 @@ func (d *Alias) listRoot(ctx context.Context, withDetails, refresh bool) []model if !withDetails || len(v) != 1 { continue } - remoteDriver, err := op.GetStorageByMountPath(v[0]) + remoteDriver, err := fs.GetStorage(v[0], &fs.GetStoragesArgs{}) if err != nil { continue } diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go index a4a6c1de1..5f02c75f5 100644 --- a/drivers/aliyundrive_open/upload.go +++ b/drivers/aliyundrive_open/upload.go @@ -163,21 +163,29 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m } count := int(math.Ceil(float64(stream.GetSize()) / float64(partSize))) createData["part_info_list"] = makePartInfos(count) + + // 检查是否是可重复读取的流 + _, isSeekable := stream.(*streamPkg.SeekableStream) + // rapid upload rapidUpload := !stream.IsForceStreamUpload() && stream.GetSize() > 100*utils.KB && d.RapidUpload if rapidUpload { log.Debugf("[aliyundrive_open] start cal pre_hash") - // read 1024 bytes to calculate pre hash - reader, err := stream.RangeRead(http_range.Range{Start: 0, Length: 1024}) - if err != nil { - return nil, err - } - hash, err := utils.HashReader(utils.SHA1, reader) - if err != nil { - return nil, err + // 优先使用预计算的 pre_hash + preHash := stream.GetHash().GetHash(utils.PRE_HASH) + if len(preHash) != utils.PRE_HASH.Width { + // 没有预计算的 pre_hash,使用 RangeRead 计算 + reader, err := stream.RangeRead(http_range.Range{Start: 0, Length: 1024}) + if err != nil { + return nil, err + } + preHash, err = utils.HashReader(utils.SHA1, reader) + if err != nil { + return nil, err + } } createData["size"] = stream.GetSize() - createData["pre_hash"] = hash + createData["pre_hash"] = preHash } var createResp CreateResp _, err, e := d.requestReturnErrResp(ctx, limiterOther, "/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) { @@ -191,9 +199,18 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m hash := stream.GetHash().GetHash(utils.SHA1) if len(hash) != utils.SHA1.Width { - _, hash, err = streamPkg.CacheFullAndHash(stream, &up, utils.SHA1) - if err != nil { - return nil, err + if isSeekable { + // 可重复读取的流,使用 StreamHashFile(RangeRead),不缓存 + hash, err = streamPkg.StreamHashFile(stream, utils.SHA1, 100, &up) + if err != nil { + return nil, err + } + } else { + // 不可重复读取的流,缓存并计算 + _, hash, err = streamPkg.CacheFullAndHash(stream, &up, utils.SHA1) + if err != nil { + return nil, err + } } } diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index fe77aca38..474dd2b98 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -1,30 +1,18 @@ package baidu_netdisk import ( - "bytes" "context" - "crypto/md5" - "encoding/hex" "errors" - "io" - "mime/multipart" - "net/http" "net/url" - "os" stdpath "path" "strconv" - "strings" "time" "github.com/OpenListTeam/OpenList/v4/drivers/base" - "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/driver" - "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" - "github.com/OpenListTeam/OpenList/v4/internal/net" - "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/avast/retry-go" log "github.com/sirupsen/logrus" ) @@ -37,6 +25,7 @@ type BaiduNetdisk struct { } var ErrUploadIDExpired = errors.New("uploadid expired") +var ErrUploadURLExpired = errors.New("upload url expired or unavailable") func (d *BaiduNetdisk) Config() driver.Config { return config @@ -199,80 +188,26 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return newObj, nil } - var ( - cache = stream.GetFile() - tmpF *os.File - err error - ) - if cache == nil { - tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return nil, err - } - defer func() { - _ = tmpF.Close() - _ = os.Remove(tmpF.Name()) - }() - cache = tmpF - } - streamSize := stream.GetSize() sliceSize := d.getSliceSize(streamSize) count := 1 if streamSize > sliceSize { count = int((streamSize + sliceSize - 1) / sliceSize) } - lastBlockSize := streamSize % sliceSize - if lastBlockSize == 0 { - lastBlockSize = sliceSize - } - - // cal md5 for first 256k data - const SliceSize int64 = 256 * utils.KB - blockList := make([]string, 0, count) - byteSize := sliceSize - fileMd5H := md5.New() - sliceMd5H := md5.New() - sliceMd5H2 := md5.New() - slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) - writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write} - if tmpF != nil { - writers = append(writers, tmpF) - } - written := int64(0) - for i := 1; i <= count; i++ { - if utils.IsCanceled(ctx) { - return nil, ctx.Err() - } - if i == count { - byteSize = lastBlockSize - } - n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize) - written += n - if err != nil && err != io.EOF { - return nil, err - } - blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil))) - sliceMd5H.Reset() - } - if tmpF != nil { - if written != streamSize { - return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize) - } - _, err = tmpF.Seek(0, io.SeekStart) - if err != nil { - return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ") - } - } - contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) - sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) - blockListStr, _ := utils.Json.MarshalToString(blockList) path := stdpath.Join(dstDir.GetPath(), stream.GetName()) mtime := stream.ModTime().Unix() ctime := stream.CreateTime().Unix() - // step.1 尝试读取已保存进度 + // step.1 流式计算MD5哈希值(使用 RangeRead,不会消耗流) + contentMd5, sliceMd5, blockList, err := d.calculateHashesStream(ctx, stream, sliceSize, &up) + if err != nil { + return nil, err + } + + blockListStr, _ := utils.Json.MarshalToString(blockList) + + // step.2 尝试读取已保存进度或执行预上传 precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5) if !ok { // 没有进度,走预上传 @@ -288,6 +223,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return fileToObj(precreateResp.File), nil } } + ensureUploadURL := func() { if precreateResp.UploadURL != "" { return @@ -295,58 +231,20 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid) } - // step.2 上传分片 + // step.3 流式上传分片 + // 创建 StreamSectionReader 用于上传 + ss, err := streamPkg.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return nil, err + } + uploadLoop: for range 2 { // 获取上传域名 ensureUploadURL() - // 并发上传 - threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, - retry.Attempts(UPLOAD_RETRY_COUNT), - retry.Delay(UPLOAD_RETRY_WAIT_TIME), - retry.MaxDelay(UPLOAD_RETRY_MAX_WAIT_TIME), - retry.DelayType(retry.BackOffDelay), - retry.RetryIf(func(err error) bool { - return !errors.Is(err, ErrUploadIDExpired) - }), - retry.LastErrorOnly(true)) - - totalParts := len(precreateResp.BlockList) - - for i, partseq := range precreateResp.BlockList { - if utils.IsCanceled(upCtx) { - break - } - if partseq < 0 { - continue - } - i, partseq := i, partseq - offset, size := int64(partseq)*sliceSize, sliceSize - if partseq+1 == count { - size = lastBlockSize - } - threadG.Go(func(ctx context.Context) error { - params := map[string]string{ - "method": "upload", - "access_token": d.AccessToken, - "type": "tmpfile", - "path": path, - "uploadid": precreateResp.Uploadid, - "partseq": strconv.Itoa(partseq), - } - section := io.NewSectionReader(cache, offset, size) - err := d.uploadSlice(ctx, precreateResp.UploadURL, params, stream.GetName(), section) - if err != nil { - return err - } - precreateResp.BlockList[i] = -1 - progress := float64(threadG.Success()+1) * 100 / float64(totalParts+1) - up(progress) - return nil - }) - } - err = threadG.Wait() + // 流式并发上传 + err = d.uploadChunksStream(ctx, ss, stream, precreateResp, path, sliceSize, count, up) if err == nil { break uploadLoop } @@ -372,13 +270,19 @@ uploadLoop: precreateResp.UploadURL = "" // 覆盖掉旧的进度 base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) + + // 尝试重新创建 StreamSectionReader(如果流支持重新读取) + ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return nil, err + } continue uploadLoop } return nil, err } defer up(100) - // step.3 创建文件 + // step.4 创建文件 var newFile File _, err = d.create(path, streamSize, 0, precreateResp.Uploadid, blockListStr, &newFile, mtime, ctime) if err != nil { @@ -427,68 +331,6 @@ func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize in return &precreateResp, nil } -func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file *io.SectionReader) error { - b := bytes.NewBuffer(make([]byte, 0, bytes.MinRead)) - mw := multipart.NewWriter(b) - _, err := mw.CreateFormFile("file", fileName) - if err != nil { - return err - } - headSize := b.Len() - err = mw.Close() - if err != nil { - return err - } - head := bytes.NewReader(b.Bytes()[:headSize]) - tail := bytes.NewReader(b.Bytes()[headSize:]) - rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, file, tail)) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl+"/rest/2.0/pcs/superfile2", rateLimitedRd) - if err != nil { - return err - } - query := req.URL.Query() - for k, v := range params { - query.Set(k, v) - } - req.URL.RawQuery = query.Encode() - req.Header.Set("Content-Type", mw.FormDataContentType()) - req.ContentLength = int64(b.Len()) + file.Size() - - client := net.NewHttpClient() - if d.UploadSliceTimeout > 0 { - client.Timeout = time.Second * time.Duration(d.UploadSliceTimeout) - } else { - client.Timeout = DEFAULT_UPLOAD_SLICE_TIMEOUT - } - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - b.Reset() - _, err = b.ReadFrom(resp.Body) - if err != nil { - return err - } - body := b.Bytes() - respStr := string(body) - log.Debugln(respStr) - lower := strings.ToLower(respStr) - // 合并 uploadid 过期检测逻辑 - if strings.Contains(lower, "uploadid") && - (strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) { - return ErrUploadIDExpired - } - - errCode := utils.Json.Get(body, "error_code").ToInt() - errNo := utils.Json.Get(body, "errno").ToInt() - if errCode != 0 || errNo != 0 { - return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", respStr) - } - return nil -} - func (d *BaiduNetdisk) GetDetails(ctx context.Context) (*model.StorageDetails, error) { du, err := d.quota(ctx) if err != nil { diff --git a/drivers/baidu_netdisk/meta.go b/drivers/baidu_netdisk/meta.go index 3f3bed022..499fcd8a8 100644 --- a/drivers/baidu_netdisk/meta.go +++ b/drivers/baidu_netdisk/meta.go @@ -31,8 +31,8 @@ type Addition struct { const ( UPLOAD_FALLBACK_API = "https://d.pcs.baidu.com" // 备用上传地址 UPLOAD_URL_EXPIRE_TIME = time.Minute * 60 // 上传地址有效期(分钟) - DEFAULT_UPLOAD_SLICE_TIMEOUT = time.Second * 60 // 上传分片请求默认超时时间 - UPLOAD_RETRY_COUNT = 3 + DEFAULT_UPLOAD_SLICE_TIMEOUT = time.Second * 180 // 上传分片请求默认超时时间(增加到3分钟以应对慢速网络) + UPLOAD_RETRY_COUNT = 5 // 增加重试次数以提高成功率 UPLOAD_RETRY_WAIT_TIME = time.Second * 1 UPLOAD_RETRY_MAX_WAIT_TIME = time.Second * 5 ) diff --git a/drivers/baidu_netdisk/upload.go b/drivers/baidu_netdisk/upload.go new file mode 100644 index 000000000..5283ffe32 --- /dev/null +++ b/drivers/baidu_netdisk/upload.go @@ -0,0 +1,311 @@ +package baidu_netdisk + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/hex" + "errors" + "io" + "mime/multipart" + "net/http" + "strconv" + "strings" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/net" + streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/avast/retry-go" + log "github.com/sirupsen/logrus" +) + +// calculateHashesStream 流式计算文件的MD5哈希值 +// 返回:文件MD5、前256KB的MD5、每个分片的MD5列表 +// 注意:此函数使用 RangeRead 读取数据,不会消耗流 +func (d *BaiduNetdisk) calculateHashesStream( + ctx context.Context, + stream model.FileStreamer, + sliceSize int64, + up *driver.UpdateProgress, +) (contentMd5 string, sliceMd5 string, blockList []string, err error) { + streamSize := stream.GetSize() + count := 1 + if streamSize > sliceSize { + count = int((streamSize + sliceSize - 1) / sliceSize) + } + lastBlockSize := streamSize % sliceSize + if lastBlockSize == 0 { + lastBlockSize = sliceSize + } + + // 前256KB的MD5 + const SliceSize int64 = 256 * utils.KB + blockList = make([]string, 0, count) + fileMd5H := md5.New() + sliceMd5H2 := md5.New() + sliceWritten := int64(0) + + // 使用固定大小的缓冲区进行流式哈希计算 + // 这样可以利用 readFullWithRangeRead 的链接刷新逻辑 + const chunkSize = 10 * 1024 * 1024 // 10MB per chunk + buf := make([]byte, chunkSize) + + for i := 0; i < count; i++ { + if utils.IsCanceled(ctx) { + return "", "", nil, ctx.Err() + } + + offset := int64(i) * sliceSize + length := sliceSize + if i == count-1 { + length = lastBlockSize + } + + // 计算分片MD5 + sliceMd5Calc := md5.New() + + // 分块读取并计算哈希 + var sliceOffset int64 = 0 + for sliceOffset < length { + readSize := chunkSize + if length-sliceOffset < int64(chunkSize) { + readSize = int(length - sliceOffset) + } + + // 使用 readFullWithRangeRead 读取数据,自动处理链接刷新 + n, err := streamPkg.ReadFullWithRangeRead(stream, buf[:readSize], offset+sliceOffset) + if err != nil { + return "", "", nil, err + } + + // 同时写入多个哈希计算器 + fileMd5H.Write(buf[:n]) + sliceMd5Calc.Write(buf[:n]) + if sliceWritten < SliceSize { + remaining := SliceSize - sliceWritten + if int64(n) > remaining { + sliceMd5H2.Write(buf[:remaining]) + sliceWritten += remaining + } else { + sliceMd5H2.Write(buf[:n]) + sliceWritten += int64(n) + } + } + + sliceOffset += int64(n) + } + + blockList = append(blockList, hex.EncodeToString(sliceMd5Calc.Sum(nil))) + + // 更新进度(哈希计算占总进度的一小部分) + if up != nil { + progress := float64(i+1) * 10 / float64(count) + (*up)(progress) + } + } + + return hex.EncodeToString(fileMd5H.Sum(nil)), + hex.EncodeToString(sliceMd5H2.Sum(nil)), + blockList, nil +} + +// uploadChunksStream 流式上传所有分片 +func (d *BaiduNetdisk) uploadChunksStream( + ctx context.Context, + ss streamPkg.StreamSectionReader, + stream model.FileStreamer, + precreateResp *PrecreateResp, + path string, + sliceSize int64, + count int, + up driver.UpdateProgress, +) error { + streamSize := stream.GetSize() + lastBlockSize := streamSize % sliceSize + if lastBlockSize == 0 { + lastBlockSize = sliceSize + } + + // 使用 OrderedGroup 保证 Before 阶段有序 + thread := min(d.uploadThread, len(precreateResp.BlockList)) + threadG, upCtx := errgroup.NewOrderedGroupWithContext(ctx, thread, + retry.Attempts(UPLOAD_RETRY_COUNT), + retry.Delay(UPLOAD_RETRY_WAIT_TIME), + retry.MaxDelay(UPLOAD_RETRY_MAX_WAIT_TIME), + retry.DelayType(retry.BackOffDelay), + retry.RetryIf(func(err error) bool { + return !errors.Is(err, ErrUploadIDExpired) + }), + retry.OnRetry(func(n uint, err error) { + // 重试前检测是否需要刷新上传 URL + if errors.Is(err, ErrUploadURLExpired) { + log.Infof("[baidu_netdisk] refreshing upload URL due to error: %v", err) + precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid) + } + }), + retry.LastErrorOnly(true)) + + totalParts := len(precreateResp.BlockList) + + for i, partseq := range precreateResp.BlockList { + if utils.IsCanceled(upCtx) { + break + } + if partseq < 0 { + continue + } + + i, partseq := i, partseq + offset := int64(partseq) * sliceSize + size := sliceSize + if partseq+1 == count { + size = lastBlockSize + } + + var reader io.ReadSeeker + + threadG.GoWithLifecycle(errgroup.Lifecycle{ + Before: func(ctx context.Context) error { + var err error + reader, err = ss.GetSectionReader(offset, size) + return err + }, + Do: func(ctx context.Context) error { + reader.Seek(0, io.SeekStart) + err := d.uploadSliceStream(ctx, precreateResp.UploadURL, path, + precreateResp.Uploadid, partseq, stream.GetName(), reader, size) + if err != nil { + return err + } + precreateResp.BlockList[i] = -1 + // 进度从10%开始(前10%是哈希计算) + progress := 10 + float64(threadG.Success()+1)*90/float64(totalParts+1) + up(progress) + return nil + }, + After: func(err error) { + ss.FreeSectionReader(reader) + }, + }) + } + + return threadG.Wait() +} + +// uploadSliceStream 上传单个分片(接受io.ReadSeeker) +func (d *BaiduNetdisk) uploadSliceStream( + ctx context.Context, + uploadUrl string, + path string, + uploadid string, + partseq int, + fileName string, + reader io.ReadSeeker, + size int64, +) error { + params := map[string]string{ + "method": "upload", + "access_token": d.AccessToken, + "type": "tmpfile", + "path": path, + "uploadid": uploadid, + "partseq": strconv.Itoa(partseq), + } + + b := bytes.NewBuffer(make([]byte, 0, bytes.MinRead)) + mw := multipart.NewWriter(b) + _, err := mw.CreateFormFile("file", fileName) + if err != nil { + return err + } + headSize := b.Len() + err = mw.Close() + if err != nil { + return err + } + head := bytes.NewReader(b.Bytes()[:headSize]) + tail := bytes.NewReader(b.Bytes()[headSize:]) + rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, reader, tail)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadUrl+"/rest/2.0/pcs/superfile2", rateLimitedRd) + if err != nil { + return err + } + query := req.URL.Query() + for k, v := range params { + query.Set(k, v) + } + req.URL.RawQuery = query.Encode() + req.Header.Set("Content-Type", mw.FormDataContentType()) + req.ContentLength = int64(b.Len()) + size + + client := net.NewHttpClient() + if d.UploadSliceTimeout > 0 { + client.Timeout = time.Second * time.Duration(d.UploadSliceTimeout) + } else { + client.Timeout = DEFAULT_UPLOAD_SLICE_TIMEOUT + } + resp, err := client.Do(req) + if err != nil { + // 检测超时或网络错误,标记需要刷新上传 URL + if isUploadURLError(err) { + log.Warnf("[baidu_netdisk] upload slice failed with network error: %v, will refresh upload URL", err) + return errors.Join(err, ErrUploadURLExpired) + } + return err + } + defer resp.Body.Close() + b.Reset() + _, err = b.ReadFrom(resp.Body) + if err != nil { + return err + } + body := b.Bytes() + respStr := string(body) + log.Debugln(respStr) + lower := strings.ToLower(respStr) + // 合并 uploadid 过期检测逻辑 + if strings.Contains(lower, "uploadid") && + (strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) { + return ErrUploadIDExpired + } + + errCode := utils.Json.Get(body, "error_code").ToInt() + errNo := utils.Json.Get(body, "errno").ToInt() + if errCode != 0 || errNo != 0 { + return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", respStr) + } + return nil +} + +// isUploadURLError 判断是否为需要刷新上传 URL 的错误 +// 包括:超时、连接被拒绝、连接重置、DNS 解析失败等网络错误 +func isUploadURLError(err error) bool { + if err == nil { + return false + } + errStr := strings.ToLower(err.Error()) + // 超时错误 + if strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "deadline exceeded") { + return true + } + // 连接错误 + if strings.Contains(errStr, "connection refused") || + strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "no such host") || + strings.Contains(errStr, "network is unreachable") { + return true + } + // EOF 错误(连接被服务器关闭) + if strings.Contains(errStr, "eof") || + strings.Contains(errStr, "broken pipe") { + return true + } + return false +} diff --git a/drivers/baidu_netdisk/util.go b/drivers/baidu_netdisk/util.go index 0e27fb305..75018a708 100644 --- a/drivers/baidu_netdisk/util.go +++ b/drivers/baidu_netdisk/util.go @@ -207,7 +207,24 @@ func (d *BaiduNetdisk) linkOfficial(file model.Obj, _ model.LinkArgs) (*model.Li return nil, err } u := fmt.Sprintf("%s&access_token=%s", resp.List[0].Dlink, d.AccessToken) - res, err := base.NoRedirectClient.R().SetHeader("User-Agent", "pan.baidu.com").Head(u) + + // Retry HEAD request with longer timeout to avoid client-side errors + // Create a client with longer timeout (base.NoRedirectClient doesn't have timeout set) + client := base.NoRedirectClient.SetTimeout(60 * time.Second) + var res *resty.Response + maxRetries := 5 + for i := 0; i < maxRetries; i++ { + res, err = client.R(). + SetHeader("User-Agent", "pan.baidu.com"). + Head(u) + if err == nil { + break + } + if i < maxRetries-1 { + log.Warnf("HEAD request failed (attempt %d/%d): %v, retrying...", i+1, maxRetries, err) + time.Sleep(time.Duration(i+1) * 2 * time.Second) // Exponential backoff: 2s, 4s, 6s, 8s + } + } if err != nil { return nil, err } diff --git a/drivers/chaoxing/driver.go b/drivers/chaoxing/driver.go index dfd25d195..54345ff12 100644 --- a/drivers/chaoxing/driver.go +++ b/drivers/chaoxing/driver.go @@ -55,13 +55,13 @@ func (d *ChaoXing) refreshCookie() error { func (d *ChaoXing) Init(ctx context.Context) error { err := d.refreshCookie() if err != nil { - log.Errorf(ctx, err.Error()) + log.Errorf(ctx, "%s", err.Error()) } d.cron = cron.NewCron(time.Hour * 12) d.cron.Do(func() { err = d.refreshCookie() if err != nil { - log.Errorf(ctx, err.Error()) + log.Errorf(ctx, "%s", err.Error()) } }) return nil diff --git a/drivers/google_drive/driver.go b/drivers/google_drive/driver.go index 94ef854f2..0e1c04f46 100644 --- a/drivers/google_drive/driver.go +++ b/drivers/google_drive/driver.go @@ -3,17 +3,29 @@ package google_drive import ( "context" "fmt" + "io" "net/http" "strconv" + "strings" + "sync" + "time" "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/avast/retry-go" "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" ) +// mkdirLocks prevents race conditions when creating folders with the same name +// Google Drive allows duplicate folder names, so we need application-level locking +var mkdirLocks sync.Map // map[string]*sync.Mutex - key is parentID + "/" + dirName + type GoogleDrive struct { model.Storage Addition @@ -67,15 +79,79 @@ func (d *GoogleDrive) Link(ctx context.Context, file model.Obj, args model.LinkA } func (d *GoogleDrive) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + // Use per-folder lock to prevent concurrent creation of same folder + // This is critical because Google Drive allows duplicate folder names + lockKey := parentDir.GetID() + "/" + dirName + lockVal, _ := mkdirLocks.LoadOrStore(lockKey, &sync.Mutex{}) + lock := lockVal.(*sync.Mutex) + lock.Lock() + defer func() { + lock.Unlock() + mkdirLocks.Delete(lockKey) + }() + + // Check if folder already exists with retry to handle API eventual consistency + escapedDirName := strings.ReplaceAll(dirName, "'", "\\'") + query := map[string]string{ + "q": fmt.Sprintf("name='%s' and '%s' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false", escapedDirName, parentDir.GetID()), + "fields": "files(id)", + } + + var existingFiles Files + err := retry.Do(func() error { + var checkErr error + _, checkErr = d.request("https://www.googleapis.com/drive/v3/files", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &existingFiles) + return checkErr + }, + retry.Context(ctx), + retry.Attempts(3), + retry.DelayType(retry.BackOffDelay), + retry.Delay(200*time.Millisecond), + ) + + // If query succeeded and folder exists, return success (idempotent) + if err == nil && len(existingFiles.Files) > 0 { + log.Debugf("[google_drive] Folder '%s' already exists in parent %s, skipping creation", dirName, parentDir.GetID()) + return nil + } + // If query failed, return error to prevent duplicate creation + if err != nil { + return fmt.Errorf("failed to check existing folder '%s': %w", dirName, err) + } + + // Create new folder (only when confirmed folder doesn't exist) data := base.Json{ "name": dirName, "parents": []string{parentDir.GetID()}, "mimeType": "application/vnd.google-apps.folder", } - _, err := d.request("https://www.googleapis.com/drive/v3/files", http.MethodPost, func(req *resty.Request) { - req.SetBody(data) - }, nil) - return err + + var createErr error + err = retry.Do(func() error { + _, createErr = d.request("https://www.googleapis.com/drive/v3/files", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return createErr + }, + retry.Context(ctx), + retry.Attempts(3), + retry.DelayType(retry.BackOffDelay), + retry.Delay(500*time.Millisecond), + ) + + if err != nil { + return err + } + + // Wait for API eventual consistency before releasing lock + // This helps prevent race conditions where a concurrent request + // checks for folder existence before the newly created folder is visible + // 500ms is needed because Google Drive API has significant sync delay + time.Sleep(500 * time.Millisecond) + + return nil } func (d *GoogleDrive) Move(ctx context.Context, srcObj, dstDir model.Obj) error { @@ -111,8 +187,50 @@ func (d *GoogleDrive) Remove(ctx context.Context, obj model.Obj) error { return err } -func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - obj := stream.GetExist() +const maxPutAuthRetries = 2 + +func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { + return d.putWithRetry(ctx, dstDir, file, up, 0) +} + +func (d *GoogleDrive) putWithRetry(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, authRetries int) error { + // 1. 准备MD5(用于完整性校验) + md5Hash := file.GetHash().GetHash(utils.MD5) + + // 检查是否是可重复读取的流 + _, isSeekable := file.(*stream.SeekableStream) + + if isSeekable { + // 可重复读取的流,使用 RangeRead 计算 hash,不缓存 + if len(md5Hash) != utils.MD5.Width { + var err error + md5Hash, err = stream.StreamHashFile(file, utils.MD5, 100, &up) + if err != nil { + return err + } + _ = md5Hash // MD5用于后续完整性校验(Google Drive会自动校验) + } + } else { + // 不可重复读取的流(如 HTTP body) + if len(md5Hash) != utils.MD5.Width { + // 缓存整个文件并计算 MD5 + var err error + _, md5Hash, err = stream.CacheFullAndHash(file, &up, utils.MD5) + if err != nil { + return err + } + _ = md5Hash // MD5用于后续完整性校验 + } else if file.GetFile() == nil { + // 有 MD5 但没有缓存,需要缓存以支持后续 RangeRead + _, err := file.CacheFullAndWriter(&up, nil) + if err != nil { + return err + } + } + } + + // 2. 初始化可恢复上传会话 + obj := file.GetExist() var ( e Error url string @@ -125,7 +243,7 @@ func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, stream model.Fi data = base.Json{} } else { data = base.Json{ - "name": stream.GetName(), + "name": file.GetName(), "parents": []string{dstDir.GetID()}, } url = "https://www.googleapis.com/upload/drive/v3/files?uploadType=resumable&supportsAllDrives=true" @@ -133,8 +251,8 @@ func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, stream model.Fi req := base.NoRedirectClient.R(). SetHeaders(map[string]string{ "Authorization": "Bearer " + d.AccessToken, - "X-Upload-Content-Type": stream.GetMimetype(), - "X-Upload-Content-Length": strconv.FormatInt(stream.GetSize(), 10), + "X-Upload-Content-Type": file.GetMimetype(), + "X-Upload-Content-Length": strconv.FormatInt(file.GetSize(), 10), }). SetError(&e).SetBody(data).SetContext(ctx) if obj != nil { @@ -146,25 +264,45 @@ func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, stream model.Fi return err } if e.Error.Code != 0 { - if e.Error.Code == 401 { + if e.Error.Code == 401 && authRetries < maxPutAuthRetries { err = d.refreshToken() if err != nil { return err } - return d.Put(ctx, dstDir, stream, up) + return d.putWithRetry(ctx, dstDir, file, up, authRetries+1) } return fmt.Errorf("%s: %v", e.Error.Message, e.Error.Errors) } + + // 3. 上传文件内容 putUrl := res.Header().Get("location") - if stream.GetSize() < d.ChunkSize*1024*1024 { - _, err = d.request(putUrl, http.MethodPut, func(req *resty.Request) { - req.SetHeader("Content-Length", strconv.FormatInt(stream.GetSize(), 10)). - SetBody(driver.NewLimitedUploadStream(ctx, stream)) - }, nil) + if file.GetSize() < d.ChunkSize*1024*1024 { + // 小文件上传:使用 RangeRead 读取整个文件(避免消费已计算hash的stream) + err = retry.Do(func() error { + reader, err := file.RangeRead(http_range.Range{Start: 0, Length: file.GetSize()}) + if err != nil { + return err + } + if closer, ok := reader.(io.Closer); ok { + defer closer.Close() + } + + _, err = d.request(putUrl, http.MethodPut, func(req *resty.Request) { + req.SetHeader("Content-Length", strconv.FormatInt(file.GetSize(), 10)). + SetBody(driver.NewLimitedUploadStream(ctx, reader)) + }, nil) + return err + }, + retry.Context(ctx), + retry.Attempts(3), + retry.DelayType(retry.BackOffDelay), + retry.Delay(time.Second), + ) + return err } else { - err = d.chunkUpload(ctx, stream, putUrl, up) + // 大文件分片上传 + return d.chunkUpload(ctx, file, putUrl, up) } - return err } func (d *GoogleDrive) GetDetails(ctx context.Context) (*model.StorageDetails, error) { diff --git a/drivers/google_drive/driver_test.go b/drivers/google_drive/driver_test.go new file mode 100644 index 000000000..b98bdcf14 --- /dev/null +++ b/drivers/google_drive/driver_test.go @@ -0,0 +1,71 @@ +package google_drive + +import ( + "sync" + "testing" +) + +func TestMaxPutAuthRetriesIsBounded(t *testing.T) { + if maxPutAuthRetries < 1 { + t.Fatalf("maxPutAuthRetries=%d, must be >= 1", maxPutAuthRetries) + } + if maxPutAuthRetries > 5 { + t.Fatalf("maxPutAuthRetries=%d, must be <= 5 to prevent excessive retries", maxPutAuthRetries) + } +} + +func TestMkdirLocksCleanedUpAfterUse(t *testing.T) { + // Reset state + mkdirLocks = sync.Map{} + + key1 := "parent-1/folder-a" + key2 := "parent-2/folder-b" + + // Simulate two MakeDir calls storing locks + mkdirLocks.LoadOrStore(key1, &sync.Mutex{}) + mkdirLocks.LoadOrStore(key2, &sync.Mutex{}) + + // Both should exist + if _, ok := mkdirLocks.Load(key1); !ok { + t.Fatal("key1 should exist") + } + if _, ok := mkdirLocks.Load(key2); !ok { + t.Fatal("key2 should exist") + } + + // After MakeDir completes, locks should be cleaned up + mkdirLocks.Delete(key1) + mkdirLocks.Delete(key2) + + if _, ok := mkdirLocks.Load(key1); ok { + t.Fatal("key1 should be deleted after cleanup") + } + if _, ok := mkdirLocks.Load(key2); ok { + t.Fatal("key2 should be deleted after cleanup") + } +} + +func TestMkdirLocksNoCrossContamination(t *testing.T) { + mkdirLocks = sync.Map{} + + key := "parent/shared-folder" + lockVal, _ := mkdirLocks.LoadOrStore(key, &sync.Mutex{}) + lock := lockVal.(*sync.Mutex) + + // Simulate concurrent access: lock should be shared for same key + lockVal2, loaded := mkdirLocks.LoadOrStore(key, &sync.Mutex{}) + if !loaded { + t.Fatal("second LoadOrStore should return existing entry") + } + lock2 := lockVal2.(*sync.Mutex) + if lock != lock2 { + t.Fatal("same key should return same mutex instance") + } + + // Different key should get different lock + lockVal3, _ := mkdirLocks.LoadOrStore("other-parent/other-folder", &sync.Mutex{}) + lock3 := lockVal3.(*sync.Mutex) + if lock == lock3 { + t.Fatal("different keys should have different mutex instances") + } +} diff --git a/drivers/google_drive/util.go b/drivers/google_drive/util.go index 042abafa4..53695deee 100644 --- a/drivers/google_drive/util.go +++ b/drivers/google_drive/util.go @@ -170,7 +170,7 @@ func (d *GoogleDrive) refreshToken() error { } log.Debug(res.String()) if e.Error != "" { - return fmt.Errorf(e.Error) + return fmt.Errorf("%s", e.Error) } d.AccessToken = resp.AccessToken return nil @@ -192,7 +192,7 @@ func (d *GoogleDrive) refreshToken() error { } log.Debug(res.String()) if e.Error != "" { - return fmt.Errorf(e.Error) + return fmt.Errorf("%s", e.Error) } d.AccessToken = resp.AccessToken return nil @@ -296,9 +296,60 @@ func (d *GoogleDrive) getFiles(id string) ([]File, error) { res = append(res, resp.Files...) } + + // Handle duplicate filenames by adding suffixes like (1), (2), etc. + // Google Drive allows multiple files with the same name in one folder, + // but OpenList uses path-based file system which requires unique names + res = handleDuplicateNames(res) + return res, nil } +// handleDuplicateNames adds suffixes to duplicate filenames to make them unique +// For example: file.txt, file (1).txt, file (2).txt +func handleDuplicateNames(files []File) []File { + if len(files) <= 1 { + return files + } + + // Track how many files with each name we've seen + nameCount := make(map[string]int) + + // First pass: count occurrences of each name + for _, file := range files { + nameCount[file.Name]++ + } + + // Second pass: add suffixes to duplicates + nameIndex := make(map[string]int) + for i := range files { + name := files[i].Name + if nameCount[name] > 1 { + index := nameIndex[name] + nameIndex[name]++ + + if index > 0 { + // Add suffix for all except the first occurrence + // Split name into base and extension + ext := "" + base := name + for j := len(name) - 1; j >= 0; j-- { + if name[j] == '.' { + ext = name[j:] + base = name[:j] + break + } + } + + // Add (1), (2), etc. suffix + files[i].Name = fmt.Sprintf("%s (%d)%s", base, index, ext) + } + } + } + + return files +} + // getTargetFileInfo gets target file details for shortcuts func (d *GoogleDrive) getTargetFileInfo(targetId string) (File, error) { var targetFile File diff --git a/drivers/google_photo/util.go b/drivers/google_photo/util.go index 3a9b66ab2..81d05b492 100644 --- a/drivers/google_photo/util.go +++ b/drivers/google_photo/util.go @@ -32,7 +32,7 @@ func (d *GooglePhoto) refreshToken() error { return err } if e.Error != "" { - return fmt.Errorf(e.Error) + return fmt.Errorf("%s", e.Error) } d.AccessToken = resp.AccessToken return nil diff --git a/drivers/lanzou/util.go b/drivers/lanzou/util.go index 9a15d428e..37844500e 100644 --- a/drivers/lanzou/util.go +++ b/drivers/lanzou/util.go @@ -91,7 +91,7 @@ func (d *LanZou) _post(url string, callback base.ReqCallback, resp interface{}, if info == "" { info = utils.Json.Get(data, "info").ToString() } - return data, fmt.Errorf(info) + return data, fmt.Errorf("%s", info) } } diff --git a/drivers/openlist/driver.go b/drivers/openlist/driver.go index 79fc51185..b37d72a06 100644 --- a/drivers/openlist/driver.go +++ b/drivers/openlist/driver.go @@ -14,6 +14,8 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/go-resty/resty/v2" @@ -195,6 +197,92 @@ func (d *OpenList) Remove(ctx context.Context, obj model.Obj) error { } func (d *OpenList) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { + // 预计算 hash(如果不存在),使用 RangeRead 不消耗 Reader + // 这样远端驱动不需要再计算,避免 HTTP body 被重复读取 + md5Hash := s.GetHash().GetHash(utils.MD5) + sha1Hash := s.GetHash().GetHash(utils.SHA1) + sha256Hash := s.GetHash().GetHash(utils.SHA256) + sha1_128kHash := s.GetHash().GetHash(utils.SHA1_128K) + preHash := s.GetHash().GetHash(utils.PRE_HASH) + + // 计算所有缺失的 hash,确保最大兼容性 + if len(md5Hash) != utils.MD5.Width { + var err error + md5Hash, err = stream.StreamHashFile(s, utils.MD5, 33, &up) + if err != nil { + log.Warnf("[openlist] failed to pre-calculate MD5: %v", err) + md5Hash = "" + } + } + if len(sha1Hash) != utils.SHA1.Width { + var err error + sha1Hash, err = stream.StreamHashFile(s, utils.SHA1, 33, &up) + if err != nil { + log.Warnf("[openlist] failed to pre-calculate SHA1: %v", err) + sha1Hash = "" + } + } + if len(sha256Hash) != utils.SHA256.Width { + var err error + sha256Hash, err = stream.StreamHashFile(s, utils.SHA256, 34, &up) + if err != nil { + log.Warnf("[openlist] failed to pre-calculate SHA256: %v", err) + sha256Hash = "" + } + } + + // 计算特殊 hash(用于秒传验证) + // SHA1_128K: 前128KB的SHA1,115网盘使用 + if len(sha1_128kHash) != utils.SHA1_128K.Width { + const PreHashSize int64 = 128 * 1024 // 128KB + hashSize := PreHashSize + if s.GetSize() < PreHashSize { + hashSize = s.GetSize() + } + reader, err := s.RangeRead(http_range.Range{Start: 0, Length: hashSize}) + if err == nil { + sha1_128kHash, err = utils.HashReader(utils.SHA1, reader) + if closer, ok := reader.(io.Closer); ok { + _ = closer.Close() + } + if err != nil { + log.Warnf("[openlist] failed to pre-calculate SHA1_128K: %v", err) + sha1_128kHash = "" + } + } else { + log.Warnf("[openlist] failed to RangeRead for SHA1_128K: %v", err) + } + } + + // PRE_HASH: 前1024字节的SHA1,阿里云盘使用 + if len(preHash) != utils.PRE_HASH.Width { + const PreHashSize int64 = 1024 // 1KB + hashSize := PreHashSize + if s.GetSize() < PreHashSize { + hashSize = s.GetSize() + } + reader, err := s.RangeRead(http_range.Range{Start: 0, Length: hashSize}) + if err == nil { + preHash, err = utils.HashReader(utils.SHA1, reader) + if closer, ok := reader.(io.Closer); ok { + _ = closer.Close() + } + if err != nil { + log.Warnf("[openlist] failed to pre-calculate PRE_HASH: %v", err) + preHash = "" + } + } else { + log.Warnf("[openlist] failed to RangeRead for PRE_HASH: %v", err) + } + } + + // 诊断日志:检查流的状态 + if ss, ok := s.(*stream.SeekableStream); ok { + if ss.Reader != nil { + log.Warnf("[openlist] WARNING: SeekableStream.Reader is not nil for file %s, stream may have been consumed!", s.GetName()) + } + } + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, @@ -206,14 +294,20 @@ func (d *OpenList) Put(ctx context.Context, dstDir model.Obj, s model.FileStream req.Header.Set("Authorization", d.Token) req.Header.Set("File-Path", path.Join(dstDir.GetPath(), s.GetName())) req.Header.Set("Password", d.MetaPassword) - if md5 := s.GetHash().GetHash(utils.MD5); len(md5) > 0 { - req.Header.Set("X-File-Md5", md5) + if len(md5Hash) > 0 { + req.Header.Set("X-File-Md5", md5Hash) + } + if len(sha1Hash) > 0 { + req.Header.Set("X-File-Sha1", sha1Hash) + } + if len(sha256Hash) > 0 { + req.Header.Set("X-File-Sha256", sha256Hash) } - if sha1 := s.GetHash().GetHash(utils.SHA1); len(sha1) > 0 { - req.Header.Set("X-File-Sha1", sha1) + if len(sha1_128kHash) > 0 { + req.Header.Set("X-File-Sha1-128k", sha1_128kHash) } - if sha256 := s.GetHash().GetHash(utils.SHA256); len(sha256) > 0 { - req.Header.Set("X-File-Sha256", sha256) + if len(preHash) > 0 { + req.Header.Set("X-File-Pre-Hash", preHash) } req.ContentLength = s.GetSize() diff --git a/drivers/quark_open/driver.go b/drivers/quark_open/driver.go index f0b8baf09..26a4288cc 100644 --- a/drivers/quark_open/driver.go +++ b/drivers/quark_open/driver.go @@ -8,6 +8,7 @@ import ( "hash" "io" "net/http" + "strings" "time" "github.com/OpenListTeam/OpenList/v4/drivers/base" @@ -18,15 +19,23 @@ import ( "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/avast/retry-go" "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" + "golang.org/x/time/rate" ) type QuarkOpen struct { model.Storage Addition - config driver.Config - conf Conf + config driver.Config + conf Conf + limiter *rate.Limiter } +// 速率限制常量:夸克开放平台限流,保守设置 +const ( + quarkRateLimit = 2.0 // 每秒2个请求,避免限流 +) + func (d *QuarkOpen) Config() driver.Config { return d.config } @@ -36,6 +45,9 @@ func (d *QuarkOpen) GetAddition() driver.Additional { } func (d *QuarkOpen) Init(ctx context.Context) error { + // 初始化速率限制器 + d.limiter = rate.NewLimiter(rate.Limit(quarkRateLimit), 1) + var resp UserInfoResp _, err := d.request(ctx, "/open/v1/user/info", http.MethodGet, nil, &resp) @@ -52,11 +64,22 @@ func (d *QuarkOpen) Init(ctx context.Context) error { return err } +// waitLimit 等待速率限制 +func (d *QuarkOpen) waitLimit(ctx context.Context) error { + if d.limiter != nil { + return d.limiter.Wait(ctx) + } + return nil +} + func (d *QuarkOpen) Drop(ctx context.Context) error { return nil } func (d *QuarkOpen) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if err := d.waitLimit(ctx); err != nil { + return nil, err + } files, err := d.GetFiles(ctx, dir.GetID()) if err != nil { return nil, err @@ -67,6 +90,9 @@ func (d *QuarkOpen) List(ctx context.Context, dir model.Obj, args model.ListArgs } func (d *QuarkOpen) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if err := d.waitLimit(ctx); err != nil { + return nil, err + } data := base.Json{ "fid": file.GetID(), } @@ -143,35 +169,116 @@ func (d *QuarkOpen) Remove(ctx context.Context, obj model.Obj) error { } func (d *QuarkOpen) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - md5Str, sha1Str := stream.GetHash().GetHash(utils.MD5), stream.GetHash().GetHash(utils.SHA1) - var ( - md5 hash.Hash - sha1 hash.Hash - ) - writers := []io.Writer{} - if len(md5Str) != utils.MD5.Width { - md5 = utils.MD5.NewFunc() - writers = append(writers, md5) - } - if len(sha1Str) != utils.SHA1.Width { - sha1 = utils.SHA1.NewFunc() - writers = append(writers, sha1) + if err := d.waitLimit(ctx); err != nil { + return err } + md5Str, sha1Str := stream.GetHash().GetHash(utils.MD5), stream.GetHash().GetHash(utils.SHA1) - if len(writers) > 0 { - _, err := stream.CacheFullAndWriter(&up, io.MultiWriter(writers...)) - if err != nil { - return err - } - if md5 != nil { - md5Str = hex.EncodeToString(md5.Sum(nil)) - } - if sha1 != nil { - sha1Str = hex.EncodeToString(sha1.Sum(nil)) + // 检查是否需要计算hash + needMD5 := len(md5Str) != utils.MD5.Width + needSHA1 := len(sha1Str) != utils.SHA1.Width + + if needMD5 || needSHA1 { + // 检查是否为可重复读取的流 + _, isSeekable := stream.(*streamPkg.SeekableStream) + + if isSeekable { + // 可重复读取的流,使用 RangeRead 一次性计算所有hash,避免重复读取 + var md5 hash.Hash + var sha1 hash.Hash + writers := []io.Writer{} + + if needMD5 { + md5 = utils.MD5.NewFunc() + writers = append(writers, md5) + } + if needSHA1 { + sha1 = utils.SHA1.NewFunc() + writers = append(writers, sha1) + } + + // 使用 RangeRead 分块读取文件,同时计算多个hash + multiWriter := io.MultiWriter(writers...) + size := stream.GetSize() + chunkSize := int64(10 * utils.MB) // 10MB per chunk + buf := make([]byte, chunkSize) + var offset int64 = 0 + + for offset < size { + readSize := min(chunkSize, size-offset) + + n, err := streamPkg.ReadFullWithRangeRead(stream, buf[:readSize], offset) + if err != nil { + return fmt.Errorf("calculate hash failed at offset %d: %w", offset, err) + } + + multiWriter.Write(buf[:n]) + offset += int64(n) + + // 更新进度(hash计算占用40%的进度) + up(40 * float64(offset) / float64(size)) + } + + if md5 != nil { + md5Str = hex.EncodeToString(md5.Sum(nil)) + } + if sha1 != nil { + sha1Str = hex.EncodeToString(sha1.Sum(nil)) + } + } else { + // 不可重复读取的流(如网络流),需要缓存并计算hash + var md5 hash.Hash + var sha1 hash.Hash + writers := []io.Writer{} + + if needMD5 { + md5 = utils.MD5.NewFunc() + writers = append(writers, md5) + } + if needSHA1 { + sha1 = utils.SHA1.NewFunc() + writers = append(writers, sha1) + } + + _, err := stream.CacheFullAndWriter(&up, io.MultiWriter(writers...)) + if err != nil { + return err + } + + if md5 != nil { + md5Str = hex.EncodeToString(md5.Sum(nil)) + } + if sha1 != nil { + sha1Str = hex.EncodeToString(sha1.Sum(nil)) + } } } - // pre - pre, err := d.upPre(ctx, stream, dstDir.GetID(), md5Str, sha1Str) + // pre - 带有 proof fail 重试逻辑 + var pre UpPreResp + var err error + err = retry.Do(func() error { + var preErr error + pre, preErr = d.upPre(ctx, stream, dstDir.GetID(), md5Str, sha1Str) + if preErr != nil { + // 检查是否为 proof fail 错误 + if strings.Contains(preErr.Error(), "proof") || strings.Contains(preErr.Error(), "43010") { + log.Warnf("[quark_open] Proof verification failed, retrying: %v", preErr) + return preErr // 返回错误触发重试 + } + // 检查是否为限流错误 + if strings.Contains(preErr.Error(), "限流") || strings.Contains(preErr.Error(), "rate") { + log.Warnf("[quark_open] Rate limited, waiting before retry: %v", preErr) + time.Sleep(2 * time.Second) // 额外等待 + return preErr + } + } + return preErr + }, + retry.Context(ctx), + retry.Attempts(3), + retry.DelayType(retry.BackOffDelay), + retry.Delay(500*time.Millisecond), + ) if err != nil { return err } @@ -181,16 +288,70 @@ func (d *QuarkOpen) Put(ctx context.Context, dstDir model.Obj, stream model.File return nil } - // get part info - partInfo := d._getPartInfo(stream, pre.Data.PartSize) - // get upload url info - upUrlInfo, err := d.upUrl(ctx, pre, partInfo) - if err != nil { + // 空文件特殊处理:跳过分片上传,直接调用 upFinish + // 由于夸克 API 对空文件处理不稳定,尝试完成上传,失败则直接成功返回 + if stream.GetSize() == 0 { + log.Infof("[quark_open] Empty file detected, attempting direct finish (task_id: %s)", pre.Data.TaskID) + err = d.upFinish(ctx, pre, []base.Json{}, []string{}) + if err != nil { + // 空文件 upFinish 失败,可能是 API 不支持,直接视为成功 + log.Warnf("[quark_open] Empty file upFinish failed: %v, treating as success", err) + } + up(100) + return nil + } + + // 带重试的分片大小调整逻辑:如果检测到 "part list exceed" 错误,自动翻倍分片大小 + var upUrlInfo UpUrlInfo + var partInfo []base.Json + currentPartSize := pre.Data.PartSize + const maxRetries = 5 + const maxPartSize = 1024 * utils.MB // 1GB 上限 + + for attempt := 0; attempt < maxRetries; attempt++ { + // 计算分片信息 + partInfo = d._getPartInfo(stream, currentPartSize) + + // 尝试获取上传 URL + upUrlInfo, err = d.upUrl(ctx, pre, partInfo) + if err == nil { + // 成功获取上传 URL + log.Infof("[quark_open] Successfully obtained upload URLs with part size: %d MB (%d parts)", + currentPartSize/(1024*1024), len(partInfo)) + break + } + + // 检查是否为分片超限错误 + if strings.Contains(err.Error(), "exceed") { + if attempt < maxRetries-1 { + // 还有重试机会,翻倍分片大小 + newPartSize := currentPartSize * 2 + + // 检查是否超过上限 + if newPartSize > maxPartSize { + return fmt.Errorf("part list exceeded and cannot increase part size (current: %d MB, max: %d MB). File may be too large for Quark API", + currentPartSize/(1024*1024), maxPartSize/(1024*1024)) + } + + log.Warnf("[quark_open] Part list exceeded (attempt %d/%d, %d parts). Retrying with doubled part size: %d MB -> %d MB", + attempt+1, maxRetries, len(partInfo), + currentPartSize/(1024*1024), newPartSize/(1024*1024)) + + currentPartSize = newPartSize + continue // 重试 + } else { + // 已达到最大重试次数 + return fmt.Errorf("part list exceeded after %d retries. Last attempt: part size %d MB, %d parts", + maxRetries, currentPartSize/(1024*1024), len(partInfo)) + } + } + + // 其他错误,直接返回 return err } - // part up - ss, err := streamPkg.NewStreamSectionReader(stream, int(pre.Data.PartSize), &up) + // part up - 使用调整后的 currentPartSize + ss, err := streamPkg.NewStreamSectionReader(stream, int(currentPartSize), &up) if err != nil { return err } @@ -204,30 +365,49 @@ func (d *QuarkOpen) Put(ctx context.Context, dstDir model.Obj, stream model.File return ctx.Err() } - offset := int64(i) * pre.Data.PartSize - size := min(pre.Data.PartSize, total-offset) + offset := int64(i) * currentPartSize + size := min(currentPartSize, total-offset) rd, err := ss.GetSectionReader(offset, size) if err != nil { return err } + + // 上传重试逻辑,包含URL刷新 + var etag string err = retry.Do(func() error { rd.Seek(0, io.SeekStart) - etag, err := d.upPart(ctx, upUrlInfo, i, driver.NewLimitedUploadStream(ctx, rd)) - if err != nil { - return err + var uploadErr error + etag, uploadErr = d.upPart(ctx, upUrlInfo, i, driver.NewLimitedUploadStream(ctx, rd)) + + // 检查是否为URL过期错误 + if uploadErr != nil && strings.Contains(uploadErr.Error(), "expire") { + log.Warnf("[quark_open] Upload URL expired for part %d, refreshing...", i) + // 刷新上传URL + newUpUrlInfo, refreshErr := d.upUrl(ctx, pre, partInfo) + if refreshErr != nil { + return fmt.Errorf("failed to refresh upload url: %w", refreshErr) + } + upUrlInfo = newUpUrlInfo + log.Infof("[quark_open] Upload URL refreshed successfully") + + // 使用新URL重试上传 + rd.Seek(0, io.SeekStart) + etag, uploadErr = d.upPart(ctx, upUrlInfo, i, driver.NewLimitedUploadStream(ctx, rd)) } - etags = append(etags, etag) - return nil + + return uploadErr }, retry.Context(ctx), retry.Attempts(3), retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second)) + ss.FreeSectionReader(rd) if err != nil { return fmt.Errorf("failed to upload part %d: %w", i, err) } + etags = append(etags, etag) up(95 * float64(offset+size) / float64(total)) } diff --git a/drivers/quark_open/meta.go b/drivers/quark_open/meta.go index 3527b52e9..ee1903939 100644 --- a/drivers/quark_open/meta.go +++ b/drivers/quark_open/meta.go @@ -13,8 +13,8 @@ type Addition struct { APIAddress string `json:"api_url_address" default:"https://api.oplist.org/quarkyun/renewapi"` AccessToken string `json:"access_token" required:"false" default:""` RefreshToken string `json:"refresh_token" required:"true"` - AppID string `json:"app_id" required:"true" help:"Keep it empty if you don't have one"` - SignKey string `json:"sign_key" required:"true" help:"Keep it empty if you don't have one"` + AppID string `json:"app_id" required:"false" default:"" help:"Optional - Auto-filled from online API, or use your own"` + SignKey string `json:"sign_key" required:"false" default:"" help:"Optional - Auto-filled from online API, or use your own"` } type Conf struct { diff --git a/drivers/quark_open/util.go b/drivers/quark_open/util.go index 788ca0e99..1a3058375 100644 --- a/drivers/quark_open/util.go +++ b/drivers/quark_open/util.go @@ -20,6 +20,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" ) @@ -283,8 +284,15 @@ func (d *QuarkOpen) getProofRange(proofSeed string, fileSize int64) (*ProofRange func (d *QuarkOpen) _getPartInfo(stream model.FileStreamer, partSize int64) []base.Json { // 计算分片信息 - partInfo := make([]base.Json, 0) total := stream.GetSize() + + // 确保partSize合理:最小4MB,避免分片过多 + const minPartSize int64 = 4 * utils.MB + if partSize < minPartSize { + partSize = minPartSize + } + + partInfo := make([]base.Json, 0) left := total partNumber := 1 @@ -304,6 +312,7 @@ func (d *QuarkOpen) _getPartInfo(stream model.FileStreamer, partSize int64) []ba partNumber++ } + log.Infof("[quark_open] Upload plan: file_size=%d, part_size=%d, part_count=%d", total, partSize, len(partInfo)) return partInfo } @@ -315,11 +324,17 @@ func (d *QuarkOpen) upUrl(ctx context.Context, pre UpPreResp, partInfo []base.Js } var resp UpUrlResp + log.Infof("[quark_open] Requesting upload URLs for %d parts (task_id: %s)", len(partInfo), pre.Data.TaskID) + _, err = d.request(ctx, "/open/v1/file/get_upload_urls", http.MethodPost, func(req *resty.Request) { req.SetBody(data) }, &resp) if err != nil { + // 如果是分片超限错误,记录详细信息 + if strings.Contains(err.Error(), "part list exceed") { + log.Errorf("[quark_open] Part list exceeded limit! Requested %d parts. Please check Quark API documentation for actual limit.", len(partInfo)) + } return upUrlInfo, err } @@ -340,13 +355,43 @@ func (d *QuarkOpen) upPart(ctx context.Context, upUrlInfo UpUrlInfo, partNumber req.Header.Set("Accept-Encoding", "gzip") req.Header.Set("User-Agent", "Go-http-client/1.1") + // ✅ 关键修复:使用更长的超时时间(10分钟) + // 慢速网络下大文件分片上传可能需要很长时间 + client := &http.Client{ + Timeout: 10 * time.Minute, + Transport: base.HttpClient.Transport, + } + // 发送请求 - resp, err := base.HttpClient.Do(req) + resp, err := client.Do(req) if err != nil { return "", err } defer resp.Body.Close() + // 检查是否为URL过期错误(403, 410等状态码) + if resp.StatusCode == 403 || resp.StatusCode == 410 { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("upload url expired (status: %d): %s", resp.StatusCode, string(body)) + } + + // ✅ 关键修复:409 PartAlreadyExist 不是错误! + // 夸克使用Sequential模式,超时重试时如果分片已存在,说明第一次其实成功了 + if resp.StatusCode == 409 { + body, _ := io.ReadAll(resp.Body) + // 从响应体中提取已存在分片的ETag + if strings.Contains(string(body), "PartAlreadyExist") { + // 尝试从XML响应中提取ETag + if etag := extractEtagFromXML(string(body)); etag != "" { + log.Infof("[quark_open] Part %d already exists (409), using existing ETag: %s", partNumber+1, etag) + return etag, nil + } + // 如果无法提取ETag,返回错误 + log.Warnf("[quark_open] Part %d already exists but cannot extract ETag from response: %s", partNumber+1, string(body)) + return "", fmt.Errorf("part already exists but ETag not found in response") + } + } + if resp.StatusCode != 200 { body, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("up status: %d, error: %s", resp.StatusCode, string(body)) @@ -355,6 +400,23 @@ func (d *QuarkOpen) upPart(ctx context.Context, upUrlInfo UpUrlInfo, partNumber return resp.Header.Get("Etag"), nil } +// extractEtagFromXML 从OSS的XML错误响应中提取ETag +// 示例: "2F796AC486BB2891E3237D8BFDE020B5" +func extractEtagFromXML(xmlBody string) string { + start := strings.Index(xmlBody, "") + if start == -1 { + return "" + } + start += len("") + end := strings.Index(xmlBody[start:], "") + if end == -1 { + return "" + } + etag := xmlBody[start : start+end] + // 移除引号 + return strings.Trim(etag, "\"") +} + func (d *QuarkOpen) upFinish(ctx context.Context, pre UpPreResp, partInfo []base.Json, etags []string) error { // 创建 part_info_list partInfoList := make([]base.Json, len(partInfo)) @@ -417,25 +479,36 @@ func (d *QuarkOpen) generateReqSign(method string, pathname string, signKey stri } func (d *QuarkOpen) refreshToken() error { - refresh, access, err := d._refreshToken() + refresh, access, appID, signKey, err := d._refreshToken() for i := 0; i < 3; i++ { if err == nil { break } else { log.Errorf("[quark_open] failed to refresh token: %s", err) } - refresh, access, err = d._refreshToken() + refresh, access, appID, signKey, err = d._refreshToken() } if err != nil { return err } log.Infof("[quark_open] token exchange: %s -> %s", d.RefreshToken, refresh) d.RefreshToken, d.AccessToken = refresh, access + + // 如果在线API返回了AppID和SignKey,保存它们(不为空时才更新) + if appID != "" && appID != d.AppID { + d.AppID = appID + log.Infof("[quark_open] AppID updated from online API: %s", appID) + } + if signKey != "" && signKey != d.SignKey { + d.SignKey = signKey + log.Infof("[quark_open] SignKey updated from online API") + } + op.MustSaveDriverStorage(d) return nil } -func (d *QuarkOpen) _refreshToken() (string, string, error) { +func (d *QuarkOpen) _refreshToken() (string, string, string, string, error) { if d.UseOnlineAPI && d.APIAddress != "" { u := d.APIAddress var resp RefreshTokenOnlineAPIResp @@ -448,19 +521,20 @@ func (d *QuarkOpen) _refreshToken() (string, string, error) { }). Get(u) if err != nil { - return "", "", err + return "", "", "", "", err } if resp.RefreshToken == "" || resp.AccessToken == "" { if resp.ErrorMessage != "" { - return "", "", fmt.Errorf("failed to refresh token: %s", resp.ErrorMessage) + return "", "", "", "", fmt.Errorf("failed to refresh token: %s", resp.ErrorMessage) } - return "", "", fmt.Errorf("empty token returned from official API, a wrong refresh token may have been used") + return "", "", "", "", fmt.Errorf("empty token returned from official API, a wrong refresh token may have been used") } - return resp.RefreshToken, resp.AccessToken, nil + // 返回所有字段,包括AppID和SignKey + return resp.RefreshToken, resp.AccessToken, resp.AppID, resp.SignKey, nil } // TODO 本地刷新逻辑 - return "", "", fmt.Errorf("local refresh token logic is not implemented yet, please use online API or contact the developer") + return "", "", "", "", fmt.Errorf("local refresh token logic is not implemented yet, please use online API or contact the developer") } // 生成认证 Cookie diff --git a/drivers/sftp/types.go b/drivers/sftp/types.go index 00a32f001..a57076e08 100644 --- a/drivers/sftp/types.go +++ b/drivers/sftp/types.go @@ -48,8 +48,8 @@ func (d *SFTP) fileToObj(f os.FileInfo, dir string) (model.Obj, error) { Size: _f.Size(), Modified: _f.ModTime(), IsFolder: _f.IsDir(), - Path: target, + Path: path, // Use symlink's own path, not target path } - log.Debugf("[sftp] obj: %+v, is symlink: %v", obj, symlink) + log.Debugf("[sftp] obj: %+v, is symlink: %v, target: %s", obj, symlink, target) return obj, nil } diff --git a/drivers/teldrive/types.go b/drivers/teldrive/types.go index 084f967e6..d75347b89 100644 --- a/drivers/teldrive/types.go +++ b/drivers/teldrive/types.go @@ -52,7 +52,7 @@ type chunkTask struct { fileName string chunkSize int64 reader io.ReadSeeker - ss stream.StreamSectionReaderIF + ss stream.StreamSectionReader } type CopyManager struct { diff --git a/go.mod b/go.mod index cd86a8147..ef022561b 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/OpenListTeam/wopan-sdk-go v0.1.5 github.com/ProtonMail/go-crypto v1.3.0 github.com/ProtonMail/gopenpgp/v2 v2.9.0 - github.com/SheltonZhu/115driver v1.2.3 + github.com/SheltonZhu/115driver v1.3.3 github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible github.com/antchfx/htmlquery v1.3.5 github.com/antchfx/xpath v1.3.5 @@ -67,9 +67,9 @@ require ( github.com/rclone/rclone v1.70.3 github.com/shirou/gopsutil/v4 v4.25.5 github.com/sirupsen/logrus v1.9.3 - github.com/spf13/afero v1.14.0 - github.com/spf13/cobra v1.9.1 - github.com/stretchr/testify v1.10.0 + github.com/spf13/afero v1.15.0 + github.com/spf13/cobra v1.10.2 + github.com/stretchr/testify v1.11.1 github.com/t3rm1n4l/go-mega v0.0.0-20241213151442-a19cff0ec7b5 github.com/tchap/go-patricia/v2 v2.3.3 github.com/u2takey/ffmpeg-go v0.5.0 @@ -284,7 +284,7 @@ require ( github.com/shabbyrobe/gocovmerge v0.0.0-20230507112040-c3350d9342df // indirect github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/spaolacci/murmur3 v1.1.0 // indirect - github.com/spf13/pflag v1.0.6 // indirect + github.com/spf13/pflag v1.0.10 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect github.com/tklauser/numcpus v0.10.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect @@ -312,4 +312,6 @@ replace github.com/ProtonMail/go-proton-api => github.com/henrybear327/go-proton replace github.com/cronokirby/saferith => github.com/Da3zKi7/saferith v0.33.0-fixed -// replace github.com/OpenListTeam/115-sdk-go => ../../OpenListTeam/115-sdk-go +replace github.com/OpenListTeam/115-sdk-go => github.com/Ironboxplus/115-sdk-go v0.2.8 + +replace github.com/KarpelesLab/reflink => github.com/OpenListTeam/reflink v0.0.0-20260520031008-ed3c0dbe8009 diff --git a/go.sum b/go.sum index 758741249..adb42d29f 100644 --- a/go.sum +++ b/go.sum @@ -21,20 +21,20 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/Da3zKi7/saferith v0.33.0-fixed h1:fnIWTk7EP9mZAICf7aQjeoAwpfrlCrkOvqmi6CbWdTk= github.com/Da3zKi7/saferith v0.33.0-fixed/go.mod h1:QKJhjoqUtBsXCAVEjw38mFqoi7DebT7kthcD7UzbnoA= -github.com/KarpelesLab/reflink v1.0.2 h1:hQ1aM3TmjU2kTNUx5p/HaobDoADYk+a6AuEinG4Cv88= -github.com/KarpelesLab/reflink v1.0.2/go.mod h1:WGkTOKNjd1FsJKBw3mu4JvrPEDJyJJ+JPtxBkbPoCok= +github.com/Ironboxplus/115-sdk-go v0.2.8 h1:JyRGXDDXktItPJansGzyLpziF1UhY30xzC5LRYTgRp4= +github.com/Ironboxplus/115-sdk-go v0.2.8/go.mod h1:cfvitk2lwe6036iNi2h+iNxwxWDifKZsSvNtrur5BqU= github.com/KirCute/zip v1.0.1 h1:L/tVZglOiDVKDi9Ud+fN49htgKdQ3Z0H80iX8OZk13c= github.com/KirCute/zip v1.0.1/go.mod h1:xhF7dCB+Bjvy+5a56lenYCKBsH+gxDNPZSy5Cp+nlXk= github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g= github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd h1:nzE1YQBdx1bq9IlZinHa+HVffy+NmVRoKr+wHN8fpLE= github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd/go.mod h1:C8yoIfvESpM3GD07OCHU7fqI7lhwyZ2Td1rbNbTAhnc= -github.com/OpenListTeam/115-sdk-go v0.2.3 h1:nDNz0GxgliW+nT2Ds486k/rp/GgJj7Ngznc98ZBUwZo= -github.com/OpenListTeam/115-sdk-go v0.2.3/go.mod h1:cfvitk2lwe6036iNi2h+iNxwxWDifKZsSvNtrur5BqU= github.com/OpenListTeam/go-cache v0.1.0 h1:eV2+FCP+rt+E4OCJqLUW7wGccWZNJMV0NNkh+uChbAI= github.com/OpenListTeam/go-cache v0.1.0/go.mod h1:AHWjKhNK3LE4rorVdKyEALDHoeMnP8SjiNyfVlB+Pz4= github.com/OpenListTeam/gsync v0.1.0 h1:ywzGybOvA3lW8K1BUjKZ2IUlT2FSlzPO4DOazfYXjcs= github.com/OpenListTeam/gsync v0.1.0/go.mod h1:h/Rvv9aX/6CdW/7B8di3xK3xNV8dUg45Fehrd/ksZ9s= +github.com/OpenListTeam/reflink v0.0.0-20260520031008-ed3c0dbe8009 h1:qLqJPr/FAsZTJiqy65JKKuFJP0V9pRVtSaIE0kqaQ8w= +github.com/OpenListTeam/reflink v0.0.0-20260520031008-ed3c0dbe8009/go.mod h1:WGkTOKNjd1FsJKBw3mu4JvrPEDJyJJ+JPtxBkbPoCok= github.com/OpenListTeam/sftpd-openlist v1.0.1 h1:j4S3iPFOpnXCUKRPS7uCT4mF2VCl34GyqvH6lqwnkUU= github.com/OpenListTeam/sftpd-openlist v1.0.1/go.mod h1:uO/wKnbvbdq3rBLmClMTZXuCnw7XW4wlAq4dZe91a40= github.com/OpenListTeam/tache v0.2.2 h1:CWFn6sr1AIYaEjC8ONdKs+LrxHyuErheenAjEqRhh4k= @@ -63,8 +63,8 @@ github.com/RoaringBitmap/roaring/v2 v2.4.5 h1:uGrrMreGjvAtTBobc0g5IrW1D5ldxDQYe2 github.com/RoaringBitmap/roaring/v2 v2.4.5/go.mod h1:FiJcsfkGje/nZBZgCu0ZxCPOKD/hVXDS2dXi7/eUFE0= github.com/STARRY-S/zip v0.2.1 h1:pWBd4tuSGm3wtpoqRZZ2EAwOmcHK6XFf7bU9qcJXyFg= github.com/STARRY-S/zip v0.2.1/go.mod h1:xNvshLODWtC4EJ702g7cTYn13G53o1+X9BWnPFpcWV4= -github.com/SheltonZhu/115driver v1.2.3 h1:94XMP/ey7VXIlpoBLIJHEoXu7N8YsELZlXVbxWcDDvk= -github.com/SheltonZhu/115driver v1.2.3/go.mod h1:Zk7Qz7SYO1QU0SJIne6DnUD2k36S3wx/KbsQpxcfY/Y= +github.com/SheltonZhu/115driver v1.3.3 h1:Bqs86D2MziYPgIOuOJF+HzG4d7GBr71ZhSCrs/U17UU= +github.com/SheltonZhu/115driver v1.3.3/go.mod h1:OujS7azslg1/bn85sPSHnNsp4/WBI9/TiijtZL9kuSQ= github.com/abbot/go-http-auth v0.4.0 h1:QjmvZ5gSC7jm3Zg54DqWE/T5m1t2AfDu6QlXJT0EVT0= github.com/abbot/go-http-auth v0.4.0/go.mod h1:Cz6ARTIzApMJDzh5bRMSUou6UMSp0IEXg9km/ci7TJM= github.com/aead/ecdh v0.2.0 h1:pYop54xVaq/CEREFEcukHRZfTdjiWvYIsZDXXrBapQQ= @@ -612,12 +612,13 @@ github.com/sorairolake/lzip-go v0.3.5/go.mod h1:N0KYq5iWrMXI0ZEXKXaS9hCyOjZUQdBD github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= -github.com/spf13/afero v1.14.0 h1:9tH6MapGnn/j0eb0yIXiLjERO8RB6xIVZRDCX7PtqWA= -github.com/spf13/afero v1.14.0/go.mod h1:acJQ8t0ohCGuMN3O+Pv0V0hgMxNYDlvdk+VTfyZmbYo= -github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= -github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= -github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= -github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -632,8 +633,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/t3rm1n4l/go-mega v0.0.0-20241213151442-a19cff0ec7b5 h1:Sa+sR8aaAMFwxhXWENEnE6ZpqhZ9d7u1RT2722Rw6hc= github.com/t3rm1n4l/go-mega v0.0.0-20241213151442-a19cff0ec7b5/go.mod h1:UdZiFUFu6e2WjjtjxivwXWcwc1N/8zgbkBR9QNucUOY= github.com/taruti/bytepool v0.0.0-20160310082835-5e3a9ea56543 h1:6Y51mutOvRGRx6KqyMNo//xk8B8o6zW9/RVmy1VamOs= @@ -692,6 +693,7 @@ go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go4.org v0.0.0-20260112195520-a5071408f32f h1:ziUVAjmTPwQMBmYR1tbdRFJPtTcQUI12fH9QQjfb0Sw= go4.org v0.0.0-20260112195520-a5071408f32f/go.mod h1:ZRJnO5ZI4zAwMFp+dS1+V6J6MSyAowhRqAE+DPa1Xp0= gocv.io/x/gocv v0.25.0/go.mod h1:Rar2PS6DV+T4FL+PM535EImD/h13hGVaHhnCu1xarBs= diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index 74e218f8f..830446808 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -1,6 +1,7 @@ package bootstrap import ( + "math" "net/url" "os" "path/filepath" @@ -96,27 +97,46 @@ func InitConfig() { confFromEnv() } - if conf.Conf.MaxConcurrency > 0 { - net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: conf.Conf.MaxConcurrency} + if conf.Conf.MaxConcurrency > math.MaxInt32 { + net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: math.MaxInt32} + } else if conf.Conf.MaxConcurrency > 0 { + net.DefaultConcurrencyLimit = &net.ConcurrencyLimit{Limit: uint32(conf.Conf.MaxConcurrency)} } - if conf.Conf.MaxBufferLimit < 0 { - m, _ := mem.VirtualMemory() - if m != nil { - conf.MaxBufferLimit = max(int(float64(m.Total)*0.05), 4*utils.MB) - conf.MaxBufferLimit -= conf.MaxBufferLimit % utils.MB + + memStat, _ := mem.VirtualMemory() + if memStat != nil { + log.Infof("total memory: %dMB, available: %dMB", memStat.Total>>20, memStat.Available>>20) + if conf.Conf.MinFreeMemory < 0 { + conf.MinFreeMemory = 0 + log.Info("disable memory cache") } else { - conf.MaxBufferLimit = 16 * utils.MB + if conf.Conf.MinFreeMemory < 16 { + t := (memStat.Total >> 20) / 10 + conf.MinFreeMemory = max(16, min(t, 1024)) << 20 + } else { + conf.MinFreeMemory = uint64(conf.Conf.MinFreeMemory) << 20 + } + log.Infof("min free memory: %dMB", conf.MinFreeMemory>>20) } + + if conf.Conf.MaxBlockLimit < 4 { + t := (memStat.Total >> 20) * 3 / 100 + conf.MaxBlockLimit = max(4, min(uint64(t), 64)) << 20 + } else { + conf.MaxBlockLimit = uint64(conf.Conf.MaxBlockLimit) << 20 + } + log.Infof("max block limit: %dMB", conf.MaxBlockLimit>>20) } else { - conf.MaxBufferLimit = conf.Conf.MaxBufferLimit * utils.MB + conf.MinFreeMemory = 0 + log.Warn("failed to get memory info, disable memory cache") } - log.Infof("max buffer limit: %dMB", conf.MaxBufferLimit/utils.MB) - if conf.Conf.MmapThreshold > 0 { - conf.MmapThreshold = conf.Conf.MmapThreshold * utils.MB + + if conf.Conf.AutoMemoryLimit > 0 { + conf.AutoMemoryLimit = uint64(conf.Conf.AutoMemoryLimit) << 20 } else { - conf.MmapThreshold = 0 + conf.AutoMemoryLimit = 0 } - log.Infof("mmap threshold: %dMB", conf.Conf.MmapThreshold) + log.Infof("auto memory limit: %dMB", conf.AutoMemoryLimit>>20) if len(conf.Conf.Log.Filter.Filters) == 0 { conf.Conf.Log.Filter.Enable = false diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index d7fd8ea47..4526aa523 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -34,6 +34,13 @@ func initSettings() { } settingMap := map[string]*model.SettingItem{} for _, v := range settings { + if v.Key == "" { + err := db.DeleteSettingItemByKey(v.Key) + if err != nil { + utils.Log.Errorf("failed delete setting with empty key: %+v", err) + } + continue + } if !isActive(v.Key) && v.Flag != model.DEPRECATED { v.Flag = model.DEPRECATED err = op.SaveSettingItem(&v) diff --git a/internal/bootstrap/run.go b/internal/bootstrap/run.go index 6740dba65..ff02509ba 100644 --- a/internal/bootstrap/run.go +++ b/internal/bootstrap/run.go @@ -15,6 +15,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/db" "github.com/OpenListTeam/OpenList/v4/internal/fs" + "github.com/OpenListTeam/OpenList/v4/internal/frontend" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/server" "github.com/OpenListTeam/OpenList/v4/server/middlewares" @@ -273,6 +274,7 @@ func Start() { func Shutdown(timeout time.Duration) { utils.Log.Println("Shutdown server...") + frontend.StopWatcher() fs.ArchiveContentUploadTaskManager.RemoveAll() ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() diff --git a/internal/conf/config.go b/internal/conf/config.go index f347380d8..bf2d3837f 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -118,10 +118,12 @@ type Config struct { TempDir string `json:"temp_dir" env:"TEMP_DIR"` BleveDir string `json:"bleve_dir" env:"BLEVE_DIR"` DistDir string `json:"dist_dir"` + FrontendRepo string `json:"frontend_repo" env:"FRONTEND_REPO"` Log LogConfig `json:"log" envPrefix:"LOG_"` DelayedStart int `json:"delayed_start" env:"DELAYED_START"` - MaxBufferLimit int `json:"max_buffer_limitMB" env:"MAX_BUFFER_LIMIT_MB"` - MmapThreshold int `json:"mmap_thresholdMB" env:"MMAP_THRESHOLD_MB"` + AutoMemoryLimit int `json:"auto_memory_limit" env:"AUTO_MEMORY_LIMIT"` + MinFreeMemory int `json:"min_free_memory" env:"MIN_FREE_MEMORY"` + MaxBlockLimit int `json:"max_block_limit" env:"MAX_BLOCK_LIMIT"` MaxConnections int `json:"max_connections" env:"MAX_CONNECTIONS"` MaxConcurrency int `json:"max_concurrency" env:"MAX_CONCURRENCY"` TlsInsecureSkipVerify bool `json:"tls_insecure_skip_verify" env:"TLS_INSECURE_SKIP_VERIFY"` @@ -162,7 +164,8 @@ func DefaultConfig(dataDir string) *Config { Host: "http://localhost:7700", Index: "openlist", }, - BleveDir: indexDir, + BleveDir: indexDir, + FrontendRepo: FrontendRepoDefault, Log: LogConfig{ Enable: true, Name: logPath, @@ -178,8 +181,7 @@ func DefaultConfig(dataDir string) *Config { }, }, }, - MaxBufferLimit: -1, - MmapThreshold: 4, + AutoMemoryLimit: 4, MaxConnections: 0, MaxConcurrency: 64, TlsInsecureSkipVerify: false, diff --git a/internal/conf/var.go b/internal/conf/var.go index 972f69997..76de34587 100644 --- a/internal/conf/var.go +++ b/internal/conf/var.go @@ -12,6 +12,7 @@ var ( GitCommit string = "unknown" Version string = "dev" WebVersion string = "rolling" + FrontendRepoDefault string = "Ironboxplus/OpenList-Frontend" ) var ( @@ -25,11 +26,16 @@ var FilenameCharMap = make(map[string]string) var PrivacyReg []*regexp.Regexp var ( - // 单个Buffer最大限制 - MaxBufferLimit = 16 * 1024 * 1024 - // 超过该阈值的Buffer将使用 mmap 分配,可主动释放内存 - MmapThreshold = 4 * 1024 * 1024 + // 在HybridCache中使用[]byte缓存数据流的限制,内存为Go自动管理,直到GC + AutoMemoryLimit uint64 = 4 * 1024 * 1024 + // 最小空闲内存,当内存不足时,HybridCache会回退到文件缓存。 + // 如果为0,HybridCache会使用文件缓存,不占用内存。 + MinFreeMemory uint64 = 16 * 1024 * 1024 + // 限制HybridCache手动管理内存单次的扩容大小,超过该阈值将分多次扩容。 + // MinFreeMemory大于0时,也限制 Downloader 的PartSize + MaxBlockLimit uint64 = 16 * 1024 * 1024 ) + var ( RawIndexHtml string ManageHtml string diff --git a/internal/db/settingitem.go b/internal/db/settingitem.go index f20e507f0..0d42aca6d 100644 --- a/internal/db/settingitem.go +++ b/internal/db/settingitem.go @@ -65,5 +65,5 @@ func SaveSettingItem(item *model.SettingItem) error { } func DeleteSettingItemByKey(key string) error { - return errors.WithStack(db.Delete(&model.SettingItem{Key: key}).Error) + return errors.WithStack(db.Where(fmt.Sprintf("%s = ?", columnName("key")), key).Delete(model.SettingItem{}).Error) } diff --git a/internal/frontend/fetcher.go b/internal/frontend/fetcher.go new file mode 100644 index 000000000..b08761578 --- /dev/null +++ b/internal/frontend/fetcher.go @@ -0,0 +1,507 @@ +package frontend + +import ( + "archive/tar" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/OpenListTeam/OpenList/v4/cmd/flags" + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" +) + +const ( + defaultFrontendRepo = "OpenListTeam/OpenList-Frontend" + versionFile = ".frontend_version" + distDirName = "dist" + maxExtractFileSize = 500 * 1024 * 1024 // 500MB per file +) + +// FetchResult contains the result of a fetch operation +type FetchResult struct { + Version string + Downloaded bool + DistPath string +} + +// GetDistPath returns the path where dynamically fetched frontend dist is stored +func GetDistPath() string { + return filepath.Join(flags.DataDir, "frontend_dist") +} + +// GetVersionFilePath returns the path to the version tracking file +func GetVersionFilePath() string { + return filepath.Join(GetDistPath(), versionFile) +} + +// HasValidDist checks if the dynamic dist directory exists and has an index.html +func HasValidDist() bool { + distPath := GetDistPath() + _, err := os.Stat(filepath.Join(distPath, distDirName, "index.html")) + return err == nil +} + +// ReadCurrentVersion reads the currently cached version from disk +func ReadCurrentVersion() string { + data, err := os.ReadFile(GetVersionFilePath()) + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) +} + +// writeVersion writes the version string to the version tracking file +func writeVersion(version string) error { + return os.WriteFile(GetVersionFilePath(), []byte(version), 0644) +} + +// FetchFromRolling downloads the frontend dist from the GitHub rolling release +func FetchFromRolling(ctx context.Context) (*FetchResult, error) { + return fetchFromTag(ctx, "rolling", "") +} + +// FetchFromLatest downloads the frontend dist from the GitHub latest release +func FetchFromLatest(ctx context.Context) (*FetchResult, error) { + return fetchFromTag(ctx, "", "") +} + +// githubRelease represents a GitHub release for JSON parsing +type githubRelease struct { + TagName string `json:"tag_name"` + Assets []struct { + BrowserDownloadURL string `json:"browser_download_url"` + Name string `json:"name"` + } `json:"assets"` + PublishedAt string `json:"published_at"` +} + +type githubRef struct { + Object struct { + Type string `json:"type"` + SHA string `json:"sha"` + } `json:"object"` +} + +type githubAnnotatedTag struct { + Object struct { + Type string `json:"type"` + SHA string `json:"sha"` + } `json:"object"` +} + +func getFrontendRepo() string { + if conf.Conf != nil && strings.TrimSpace(conf.Conf.FrontendRepo) != "" { + return strings.TrimSpace(conf.Conf.FrontendRepo) + } + return defaultFrontendRepo +} + +func shortHash(sha string) string { + const shortLen = 12 + if len(sha) > shortLen { + return sha[:shortLen] + } + return sha +} + +func versionIdentifier(tag, commitSHA, fallback string) string { + if strings.TrimSpace(commitSHA) != "" { + return fmt.Sprintf("%s@%s", tag, shortHash(commitSHA)) + } + if strings.TrimSpace(fallback) != "" { + return fallback + } + return tag +} + +func resolveTagCommitSHA(ctx context.Context, client *http.Client, baseURL, tag string) (string, error) { + if strings.TrimSpace(tag) == "" { + return "", fmt.Errorf("empty tag") + } + + apiBase := "https://api.github.com" + if baseURL != "" { + apiBase = strings.TrimRight(baseURL, "/") + } + + repo := getFrontendRepo() + refURL := fmt.Sprintf("%s/repos/%s/git/ref/tags/%s", apiBase, repo, url.PathEscape(tag)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, refURL, nil) + if err != nil { + return "", fmt.Errorf("create ref request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "OpenList") + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("fetch tag ref: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("tag ref API returned %d: %s", resp.StatusCode, string(body)) + } + + var ref githubRef + if err := json.NewDecoder(resp.Body).Decode(&ref); err != nil { + return "", fmt.Errorf("decode ref JSON: %w", err) + } + + switch ref.Object.Type { + case "commit": + if ref.Object.SHA == "" { + return "", fmt.Errorf("empty commit sha in ref response") + } + return ref.Object.SHA, nil + case "tag": + if ref.Object.SHA == "" { + return "", fmt.Errorf("empty tag sha in ref response") + } + tagObjURL := fmt.Sprintf("%s/repos/%s/git/tags/%s", apiBase, repo, ref.Object.SHA) + tagReq, err := http.NewRequestWithContext(ctx, http.MethodGet, tagObjURL, nil) + if err != nil { + return "", fmt.Errorf("create tag object request: %w", err) + } + tagReq.Header.Set("Accept", "application/vnd.github.v3+json") + tagReq.Header.Set("User-Agent", "OpenList") + + tagResp, err := client.Do(tagReq) + if err != nil { + return "", fmt.Errorf("fetch tag object: %w", err) + } + defer tagResp.Body.Close() + + if tagResp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(tagResp.Body) + return "", fmt.Errorf("tag object API returned %d: %s", tagResp.StatusCode, string(body)) + } + + var tagObj githubAnnotatedTag + if err := json.NewDecoder(tagResp.Body).Decode(&tagObj); err != nil { + return "", fmt.Errorf("decode tag object JSON: %w", err) + } + if tagObj.Object.SHA == "" { + return "", fmt.Errorf("empty object sha in tag object response") + } + return tagObj.Object.SHA, nil + default: + if ref.Object.SHA == "" { + return "", fmt.Errorf("unsupported ref object type %q with empty sha", ref.Object.Type) + } + return ref.Object.SHA, nil + } +} + +// fetchFromTag downloads frontend dist from a GitHub release tag. +// If baseURL is non-empty, it replaces api.github.com (used for testing). +func fetchFromTag(ctx context.Context, tag string, baseURL string) (*FetchResult, error) { + repo := getFrontendRepo() + var apiURL string + if baseURL != "" { + if tag == "" { + apiURL = fmt.Sprintf("%s/repos/%s/releases/latest", baseURL, repo) + } else { + apiURL = fmt.Sprintf("%s/repos/%s/releases/tags/%s", baseURL, repo, tag) + } + } else { + if tag == "" { + apiURL = fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo) + } else { + apiURL = fmt.Sprintf("https://api.github.com/repos/%s/releases/tags/%s", repo, tag) + } + } + + client := newHTTPClient() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "OpenList") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch release info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("github API returned %d: %s", resp.StatusCode, string(body)) + } + + var release githubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, fmt.Errorf("decode release JSON: %w", err) + } + + // Find the dist tarball URL (non-lite) + var tarURL string + for _, asset := range release.Assets { + if strings.Contains(asset.Name, "openlist-frontend-dist") && + !strings.Contains(asset.Name, "lite") && + strings.HasSuffix(asset.Name, ".tar.gz") { + tarURL = asset.BrowserDownloadURL + break + } + } + if tarURL == "" { + return nil, fmt.Errorf("no frontend dist tarball found in release %s", release.TagName) + } + + commitSHA, err := resolveTagCommitSHA(ctx, client, baseURL, release.TagName) + if err != nil { + utils.Log.Warnf("[frontend] failed to resolve tag %s hash: %v", release.TagName, err) + } + + resolvedVersion := versionIdentifier(release.TagName, commitSHA, tarURL) + + // Use tag+commit-hash as the primary version identifier. + // For rolling releases the tag itself is static, but its target commit moves. + // If hash resolve fails, fallback to tarball URL so updates can still be detected. + currentVersion := ReadCurrentVersion() + if currentVersion == resolvedVersion && HasValidDist() { + utils.Log.Infof("[frontend] version %s already cached, skipping download", resolvedVersion) + return &FetchResult{ + Version: resolvedVersion, + Downloaded: false, + DistPath: filepath.Join(GetDistPath(), distDirName), + }, nil + } + + utils.Log.Infof("[frontend] downloading version %s from %s", resolvedVersion, tarURL) + if err := downloadAndExtract(ctx, client, tarURL); err != nil { + return nil, fmt.Errorf("download and extract: %w", err) + } + + if err := writeVersion(resolvedVersion); err != nil { + utils.Log.Warnf("[frontend] failed to write version file: %v", err) + } + + utils.Log.Infof("[frontend] successfully fetched version %s", resolvedVersion) + return &FetchResult{ + Version: resolvedVersion, + Downloaded: true, + DistPath: filepath.Join(GetDistPath(), distDirName), + }, nil +} + +func downloadAndExtract(ctx context.Context, client *http.Client, url string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("create download request: %w", err) + } + req.Header.Set("User-Agent", "OpenList") + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("download tarball: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download returned status %d", resp.StatusCode) + } + + destDir := GetDistPath() + tmpDir := filepath.Join(destDir, ".tmp-"+fmt.Sprintf("%d", time.Now().UnixNano())) + if err := os.MkdirAll(tmpDir, 0755); err != nil { + return fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tmpDir) + + if err := extractTarGz(resp.Body, tmpDir); err != nil { + return fmt.Errorf("extract tar.gz: %w", err) + } + + // Determine the source directory: + // If the tarball contains a "dist" subdirectory, use it; + // otherwise, the files are at the root and we use tmpDir directly. + srcDir := tmpDir + if _, err := os.Stat(filepath.Join(tmpDir, distDirName)); err == nil { + srcDir = filepath.Join(tmpDir, distDirName) + } + + // Atomic swap: hold lock to minimize the window where dist is absent + finalDir := filepath.Join(destDir, distDirName) + oldDir := filepath.Join(destDir, distDirName+".old") + + distSwapMu.Lock() + os.RemoveAll(oldDir) + os.Rename(finalDir, oldDir) + if err := os.Rename(srcDir, finalDir); err != nil { + os.RemoveAll(finalDir) + os.Rename(oldDir, finalDir) + distSwapMu.Unlock() + return fmt.Errorf("rename new dist: %w", err) + } + distSwapMu.Unlock() + os.RemoveAll(oldDir) + + return nil +} + +func extractTarGz(r io.Reader, dest string) error { + gzr, err := gzip.NewReader(r) + if err != nil { + return fmt.Errorf("gzip reader: %w", err) + } + defer gzr.Close() + + tr := tar.NewReader(gzr) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("tar next: %w", err) + } + + // Normalize: strip leading ./ so "./dist" becomes "dist" + name := strings.TrimPrefix(hdr.Name, "./") + if name == "" || name == "." { + continue // skip bare directory entry + } + + target := filepath.Join(dest, name) + + // Security: prevent path traversal + if !strings.HasPrefix(filepath.Clean(target), filepath.Clean(dest)+string(os.PathSeparator)) { + return fmt.Errorf("path traversal detected: %s", hdr.Name) + } + + switch hdr.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(target, os.FileMode(hdr.Mode)); err != nil { + return err + } + case tar.TypeReg: + if hdr.Size > maxExtractFileSize { + return fmt.Errorf("file too large: %s (%d bytes)", hdr.Name, hdr.Size) + } + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return err + } + f, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)) + if err != nil { + return err + } + if _, err := io.Copy(f, io.LimitReader(tr, maxExtractFileSize)); err != nil { + f.Close() + return err + } + f.Close() + } + } + return nil +} + +// newHTTPClient creates an HTTP client that respects proxy configuration +func newHTTPClient() *http.Client { + transport := &http.Transport{} + if conf.Conf != nil && conf.Conf.ProxyAddress != "" { + if proxyURL := mustParseURL(conf.Conf.ProxyAddress); proxyURL != nil { + transport.Proxy = http.ProxyURL(proxyURL) + } + } + return &http.Client{ + Transport: transport, + Timeout: 5 * time.Minute, + } +} + +func mustParseURL(raw string) *url.URL { + u, err := url.Parse(raw) + if err != nil { + utils.Log.Warnf("[frontend] invalid proxy URL %q: %v", raw, err) + return nil + } + return u +} + +// EnsureDist ensures a valid frontend dist is available, fetching if necessary. +// It first tries the dynamic dist, then falls back to fetching from GitHub. +// Returns the path to the dist directory, or empty string if no dist is available. +func EnsureDist(ctx context.Context) string { + // If user explicitly configured dist_dir, use that + if conf.Conf != nil && conf.Conf.DistDir != "" { + if _, err := os.Stat(filepath.Join(conf.Conf.DistDir, "index.html")); err == nil { + return conf.Conf.DistDir + } + } + + // Check if dynamic dist already exists + if HasValidDist() { + return filepath.Join(GetDistPath(), distDirName) + } + + // If auto-fetch is enabled (and WebVersion is rolling/beta/dev), try fetching + if shouldAutoFetch() { + utils.Log.Infof("[frontend] no local dist found, fetching from rolling release...") + result, err := FetchFromRolling(ctx) + if err != nil { + utils.Log.Warnf("[frontend] failed to fetch from rolling: %v", err) + // Fall through to return empty (embedded dist will be used as fallback) + return "" + } + return result.DistPath + } + + return "" +} + +func shouldAutoFetch() bool { + v := conf.WebVersion + return v == "" || v == "rolling" || v == "beta" || v == "dev" +} + +// Ensure the directory exists for the frontend dist +func init() { + _ = os.MkdirAll(GetDistPath(), 0755) +} + +var distSwapMu sync.Mutex + +// Ensure that the sync.Once pattern is used for the fetcher +var ( + fetchMu sync.Mutex + fetchDone bool + fetchResult string +) + +// EnsureDistOnce is a thread-safe version of EnsureDist that only fetches once per process. +// On failure, it does not lock the state so subsequent calls can retry. +func EnsureDistOnce(ctx context.Context) string { + fetchMu.Lock() + defer fetchMu.Unlock() + if fetchDone { + return fetchResult + } + result := EnsureDist(ctx) + if result != "" { + fetchResult = result + fetchDone = true + } + return result +} + +// ResetFetchState resets the fetch state (used for testing or re-fetch) +func ResetFetchState() { + fetchMu.Lock() + defer fetchMu.Unlock() + fetchDone = false + fetchResult = "" +} diff --git a/internal/frontend/fetcher_test.go b/internal/frontend/fetcher_test.go new file mode 100644 index 000000000..3ed6d9546 --- /dev/null +++ b/internal/frontend/fetcher_test.go @@ -0,0 +1,627 @@ +package frontend + +import ( + "archive/tar" + "compress/gzip" + "context" + "fmt" + "io" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "regexp" + "strings" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" +) + +// createTestTarGz creates a tar.gz containing files with given names and contents +func createTestTarGz(t *testing.T, files map[string]string) []byte { + t.Helper() + pr, pw := io.Pipe() + go func() { + defer pw.Close() + gw := gzip.NewWriter(pw) + defer gw.Close() + tw := tar.NewWriter(gw) + defer tw.Close() + for name, content := range files { + hdr := &tar.Header{ + Name: name, + Mode: 0644, + Size: int64(len(content)), + } + if err := tw.WriteHeader(hdr); err != nil { + return + } + if _, err := tw.Write([]byte(content)); err != nil { + return + } + } + }() + data, err := io.ReadAll(pr) + if err != nil { + t.Fatalf("read tar.gz: %v", err) + } + return data +} + +func TestExtractTarGz(t *testing.T) { + tmpDir := t.TempDir() + files := map[string]string{ + "dist/index.html": "hello", + "dist/assets/app.js": "console.log('app')", + "dist/assets/style.css": "body {}", + "dist/images/logo.svg": "", + } + tarData := createTestTarGz(t, files) + + err := extractTarGz(strings.NewReader(string(tarData)), tmpDir) + if err != nil { + t.Fatalf("extractTarGz: %v", err) + } + + for name, expectedContent := range files { + path := filepath.Join(tmpDir, name) + data, err := os.ReadFile(path) + if err != nil { + t.Errorf("read %s: %v", name, err) + continue + } + if string(data) != expectedContent { + t.Errorf("content of %s: got %q, want %q", name, string(data), expectedContent) + } + } +} + +func TestExtractTarGzDotSlash(t *testing.T) { + tmpDir := t.TempDir() + files := map[string]string{ + "./dist/index.html": "dot-slash", + "./": "", + } + tarData := createTestTarGz(t, files) + + err := extractTarGz(strings.NewReader(string(tarData)), tmpDir) + if err != nil { + t.Fatalf("extractTarGz with ./ prefix: %v", err) + } + + data, err := os.ReadFile(filepath.Join(tmpDir, "dist", "index.html")) + if err != nil { + t.Fatalf("read dist/index.html: %v", err) + } + if string(data) != "dot-slash" { + t.Errorf("got %q, want dot-slash content", string(data)) + } +} + +func TestExtractTarGzPathTraversal(t *testing.T) { + tmpDir := t.TempDir() + files := map[string]string{ + "../../../etc/passwd": "root:x:0:0", + } + tarData := createTestTarGz(t, files) + + err := extractTarGz(strings.NewReader(string(tarData)), tmpDir) + if err == nil { + t.Fatal("expected error for path traversal, got nil") + } + if !strings.Contains(err.Error(), "path traversal") { + t.Errorf("expected path traversal error, got: %v", err) + } +} + +func TestExtractTarGzRejectsOversizedFile(t *testing.T) { + tmpDir := t.TempDir() + // Create a tar.gz with a file whose header claims a size exceeding the limit + pr, pw := io.Pipe() + go func() { + defer pw.Close() + gw := gzip.NewWriter(pw) + defer gw.Close() + tw := tar.NewWriter(gw) + defer tw.Close() + hdr := &tar.Header{ + Name: "dist/huge.bin", + Mode: 0644, + Size: maxExtractFileSize + 1, + } + _ = tw.WriteHeader(hdr) + // Write just enough to pass; the size check should reject before reading + buf := make([]byte, 1024) + for written := int64(0); written < hdr.Size; written += int64(len(buf)) { + n := min(int64(len(buf)), hdr.Size-written) + _, _ = tw.Write(buf[:n]) + } + }() + data, _ := io.ReadAll(pr) + + err := extractTarGz(strings.NewReader(string(data)), tmpDir) + if err == nil { + t.Fatal("expected error for oversized file, got nil") + } + if !strings.Contains(err.Error(), "too large") { + t.Errorf("expected 'too large' error, got: %v", err) + } +} + +func TestHasValidDist(t *testing.T) { + if HasValidDist() { + t.Log("HasValidDist returned true (may have existing dist from previous runs)") + } +} + +func TestWriteAndReadVersion(t *testing.T) { + _ = os.MkdirAll(GetDistPath(), 0755) + versionPath := GetVersionFilePath() + + origData, origErr := os.ReadFile(versionPath) + defer func() { + if origErr == nil { + _ = os.WriteFile(versionPath, origData, 0644) + } else { + _ = os.Remove(versionPath) + } + }() + + testVersion := "v1.0.0-test" + if err := writeVersion(testVersion); err != nil { + t.Fatalf("writeVersion: %v", err) + } + + got := ReadCurrentVersion() + if got != testVersion { + t.Errorf("ReadCurrentVersion: got %q, want %q", got, testVersion) + } +} + +func TestShouldAutoFetch(t *testing.T) { + origVersion := conf.WebVersion + defer func() { conf.WebVersion = origVersion }() + + tests := []struct { + version string + want bool + }{ + {"", true}, + {"rolling", true}, + {"beta", true}, + {"dev", true}, + {"v3.0.0", false}, + {"latest", false}, + } + + for _, tt := range tests { + t.Run(tt.version, func(t *testing.T) { + conf.WebVersion = tt.version + if got := shouldAutoFetch(); got != tt.want { + t.Errorf("shouldAutoFetch(%q) = %v, want %v", tt.version, got, tt.want) + } + }) + } +} + +func TestFetchFromRollingIntegration(t *testing.T) { + files := map[string]string{ + "./dist/index.html": "integration", + "./dist/assets/test.js": "console.log('test')", + } + tarData := createTestTarGz(t, files) + + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/repos/OpenListTeam/OpenList-Frontend/releases/tags/rolling": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(fmt.Sprintf(`{ + "tag_name": "rolling-test", + "assets": [{ + "name": "openlist-frontend-dist.tar.gz", + "browser_download_url": "%s/download/frontend.tar.gz" + }] + }`, ts.URL))) + case "/repos/OpenListTeam/OpenList-Frontend/git/ref/tags/rolling-test": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{ + "object": { + "type": "commit", + "sha": "0123456789abcdef0123456789abcdef01234567" + } + }`)) + case "/download/frontend.tar.gz": + w.Header().Set("Content-Type", "application/gzip") + w.WriteHeader(200) + w.Write(tarData) + default: + w.WriteHeader(404) + } + })) + defer ts.Close() + + ResetFetchState() + + destDir := GetDistPath() + os.RemoveAll(filepath.Join(destDir, distDirName)) + os.Remove(GetVersionFilePath()) + + ctx := context.Background() + result, err := fetchFromTag(ctx, "rolling", ts.URL) + if err != nil { + t.Fatalf("fetchFromTag: %v", err) + } + + if result.Version != "rolling-test@0123456789ab" { + t.Errorf("version: got %q, want %q", result.Version, "rolling-test@0123456789ab") + } + if !result.Downloaded { + t.Error("expected Downloaded=true") + } + + idx, err := os.ReadFile(filepath.Join(result.DistPath, "index.html")) + if err != nil { + t.Fatalf("read index.html: %v", err) + } + if string(idx) != "integration" { + t.Errorf("index.html: got %q", string(idx)) + } + + ver := ReadCurrentVersion() + expectedVer := "rolling-test@0123456789ab" + if ver != expectedVer { + t.Errorf("version file: got %q, want %q", ver, expectedVer) + } +} + +func TestLegacyConfigJSONGetsDefaultFrontendRepo(t *testing.T) { + cfg := conf.DefaultConfig("data") + if err := json.Unmarshal([]byte(`{"site_url":"https://example.com"}`), cfg); err != nil { + t.Fatalf("unmarshal legacy config: %v", err) + } + if cfg.FrontendRepo != conf.FrontendRepoDefault { + t.Fatalf("FrontendRepo: got %q, want %q", cfg.FrontendRepo, conf.FrontendRepoDefault) + } +} + +func TestDefaultConfigUsesBuiltFrontendRepoDefault(t *testing.T) { + orig := conf.FrontendRepoDefault + conf.FrontendRepoDefault = "Ironboxplus/OpenList-Frontend" + defer func() { conf.FrontendRepoDefault = orig }() + + cfg := conf.DefaultConfig("data") + if cfg.FrontendRepo != "Ironboxplus/OpenList-Frontend" { + t.Fatalf("FrontendRepo: got %q, want %q", cfg.FrontendRepo, "Ironboxplus/OpenList-Frontend") + } +} + +func TestExistingConfigJSONKeepsFrontendRepo(t *testing.T) { + cfg := conf.DefaultConfig("data") + if err := json.Unmarshal([]byte(`{"frontend_repo":"Ironboxplus/OpenList-Frontend"}`), cfg); err != nil { + t.Fatalf("unmarshal config with frontend_repo: %v", err) + } + if cfg.FrontendRepo != "Ironboxplus/OpenList-Frontend" { + t.Fatalf("FrontendRepo: got %q", cfg.FrontendRepo) + } +} + +func TestFetchFromTagUsesConfiguredFrontendRepo(t *testing.T) { + origConf := conf.Conf + if origConf == nil { + conf.Conf = &conf.Config{} + } else { + confCopy := *origConf + conf.Conf = &confCopy + } + defer func() { conf.Conf = origConf }() + conf.Conf.FrontendRepo = "Ironboxplus/OpenList-Frontend" + + files := map[string]string{ + "./dist/index.html": "custom-repo", + } + tarData := createTestTarGz(t, files) + + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/repos/Ironboxplus/OpenList-Frontend/releases/tags/rolling": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(fmt.Sprintf(`{ + "tag_name": "rolling-custom", + "assets": [{ + "name": "openlist-frontend-dist.tar.gz", + "browser_download_url": "%s/download/custom.tar.gz" + }] + }`, ts.URL))) + case "/repos/Ironboxplus/OpenList-Frontend/git/ref/tags/rolling-custom": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "object": { + "type": "commit", + "sha": "fedcba9876543210fedcba9876543210fedcba98" + } + }`)) + case "/download/custom.tar.gz": + w.Header().Set("Content-Type", "application/gzip") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(tarData) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + + ResetFetchState() + _ = os.MkdirAll(GetDistPath(), 0o755) + _ = os.RemoveAll(filepath.Join(GetDistPath(), distDirName)) + _ = os.Remove(GetVersionFilePath()) + + result, err := fetchFromTag(context.Background(), "rolling", ts.URL) + if err != nil { + t.Fatalf("fetchFromTag(custom repo): %v", err) + } + if result.Version != "rolling-custom@fedcba987654" { + t.Fatalf("version: got %q, want %q", result.Version, "rolling-custom@fedcba987654") + } + + data, err := os.ReadFile(filepath.Join(result.DistPath, "index.html")) + if err != nil { + t.Fatalf("read index.html: %v", err) + } + if string(data) != "custom-repo" { + t.Fatalf("index.html: got %q", string(data)) + } +} + +func TestResolveTagCommitSHA_AnnotatedTag(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/repos/OpenListTeam/OpenList-Frontend/git/ref/tags/rolling": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{ + "object": { + "type": "tag", + "sha": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + } + }`)) + case "/repos/OpenListTeam/OpenList-Frontend/git/tags/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{ + "object": { + "type": "commit", + "sha": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" + } + }`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + + sha, err := resolveTagCommitSHA(context.Background(), ts.Client(), ts.URL, "rolling") + if err != nil { + t.Fatalf("resolveTagCommitSHA: %v", err) + } + if sha != "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" { + t.Fatalf("sha: got %q, want %q", sha, "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb") + } +} + +func TestVersionIdentifierFallback(t *testing.T) { + tests := []struct { + name string + tag string + sha string + fallback string + want string + }{ + {name: "hash preferred", tag: "rolling", sha: "0123456789abcdef", fallback: "fallback", want: "rolling@0123456789ab"}, + {name: "fallback url", tag: "rolling", sha: "", fallback: "http://example.com/dist.tar.gz", want: "http://example.com/dist.tar.gz"}, + {name: "tag only", tag: "rolling", sha: "", fallback: "", want: "rolling"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := versionIdentifier(tt.tag, tt.sha, tt.fallback) + if got != tt.want { + t.Fatalf("versionIdentifier() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestResolveTagCommitSHA_RealGitHubWithProxy10808(t *testing.T) { + if os.Getenv("OPENLIST_REAL_GITHUB_TEST") != "1" { + t.Skip("set OPENLIST_REAL_GITHUB_TEST=1 to run real GitHub integration test (proxy 127.0.0.1:10808 recommended)") + } + + if os.Getenv("HTTP_PROXY") == "" && os.Getenv("http_proxy") == "" { + _ = os.Setenv("HTTP_PROXY", "http://127.0.0.1:10808") + } + if os.Getenv("HTTPS_PROXY") == "" && os.Getenv("https_proxy") == "" { + _ = os.Setenv("HTTPS_PROXY", "http://127.0.0.1:10808") + } + + client := newHTTPClient() + sha, err := resolveTagCommitSHA(context.Background(), client, "", "rolling") + if err != nil { + t.Fatalf("resolveTagCommitSHA(real): %v", err) + } + + matched, _ := regexp.MatchString("^[0-9a-f]{40}$", sha) + if !matched { + t.Fatalf("sha format invalid: %q", sha) + } + + // Optional sanity: ensure API can fetch release JSON in real scenario + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://api.github.com/repos/OpenListTeam/OpenList-Frontend/releases/tags/rolling", nil) + if err != nil { + t.Fatalf("create release request: %v", err) + } + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "OpenList") + resp, err := client.Do(req) + if err != nil { + t.Fatalf("fetch release: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("release API status=%d body=%s", resp.StatusCode, string(body)) + } + + var release map[string]any + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + t.Fatalf("decode release: %v", err) + } + if release["tag_name"] == nil { + t.Fatalf("release tag_name missing") + } +} + +func TestStaleCacheOverriddenByNewerRelease(t *testing.T) { + oldFiles := map[string]string{"./dist/index.html": "old"} + oldTar := createTestTarGz(t, oldFiles) + newFiles := map[string]string{"./dist/index.html": "new"} + newTar := createTestTarGz(t, newFiles) + + callCount := 0 + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case fmt.Sprintf("/repos/%s/releases/tags/rolling", getFrontendRepo()): + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(fmt.Sprintf(`{ + "tag_name": "rolling", + "assets": [{"name": "openlist-frontend-dist.tar.gz", "browser_download_url": "%s/download/dist.tar.gz"}] + }`, ts.URL))) + case fmt.Sprintf("/repos/%s/git/ref/tags/rolling", getFrontendRepo()): + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"object":{"type":"commit","sha":"aaaaaaaaaaaabbbbbbbbbbbbccccccccccccdddd"}}`)) + case "/download/dist.tar.gz": + callCount++ + w.Header().Set("Content-Type", "application/gzip") + w.WriteHeader(200) + if callCount == 1 { + _, _ = w.Write(oldTar) + } else { + _, _ = w.Write(newTar) + } + default: + w.WriteHeader(404) + } + })) + defer ts.Close() + + ResetFetchState() + destDir := GetDistPath() + _ = os.MkdirAll(destDir, 0755) + os.RemoveAll(filepath.Join(destDir, distDirName)) + os.Remove(GetVersionFilePath()) + + // Write a stale cached version + _ = writeVersion("rolling@stale000000000") + _ = os.MkdirAll(filepath.Join(destDir, distDirName), 0755) + _ = os.WriteFile(filepath.Join(destDir, distDirName, "index.html"), []byte("stale"), 0644) + + // Fetch should detect version mismatch and download + result, err := fetchFromTag(context.Background(), "rolling", ts.URL) + if err != nil { + t.Fatalf("fetchFromTag with stale cache: %v", err) + } + if !result.Downloaded { + t.Error("expected Downloaded=true when cache is stale") + } + + data, err := os.ReadFile(filepath.Join(result.DistPath, "index.html")) + if err != nil { + t.Fatalf("read index.html: %v", err) + } + if string(data) != "old" { + t.Errorf("index.html: got %q, want old content", string(data)) + } +} + +func TestWatcherTriggersCallbackOnNewVersion(t *testing.T) { + files := map[string]string{"./dist/index.html": "watcher"} + tarData := createTestTarGz(t, files) + + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case fmt.Sprintf("/repos/%s/releases/tags/rolling", getFrontendRepo()): + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(fmt.Sprintf(`{ + "tag_name": "rolling", + "assets": [{"name": "openlist-frontend-dist.tar.gz", "browser_download_url": "%s/download/dist.tar.gz"}] + }`, ts.URL))) + case fmt.Sprintf("/repos/%s/git/ref/tags/rolling", getFrontendRepo()): + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"object":{"type":"commit","sha":"watchertest1234567890watchertest1234567890"}}`)) + case "/download/dist.tar.gz": + w.Header().Set("Content-Type", "application/gzip") + w.WriteHeader(200) + _, _ = w.Write(tarData) + default: + w.WriteHeader(404) + } + })) + defer ts.Close() + + ResetFetchState() + destDir := GetDistPath() + _ = os.MkdirAll(destDir, 0755) + os.RemoveAll(filepath.Join(destDir, distDirName)) + os.Remove(GetVersionFilePath()) + + // The watcher's check() calls FetchFromRolling which uses api.github.com. + // For this test, call fetchFromTag directly and verify Downloaded triggers callback logic. + result, err := fetchFromTag(context.Background(), "rolling", ts.URL) + if err != nil { + t.Fatalf("fetchFromTag: %v", err) + } + if !result.Downloaded { + t.Fatal("expected Downloaded=true for watcher callback trigger") + } + if result.Version != "rolling@watchertest1" { + t.Errorf("version: got %q", result.Version) + } +} + +func TestEnsureDistOnceFailureDoesNotLock(t *testing.T) { + ResetFetchState() + _ = os.MkdirAll(GetDistPath(), 0755) + + destDir := GetDistPath() + os.RemoveAll(filepath.Join(destDir, distDirName)) + os.Remove(GetVersionFilePath()) + + origVersion := conf.WebVersion + conf.WebVersion = "v3.0.0" // shouldAutoFetch returns false + defer func() { conf.WebVersion = origVersion }() + + ctx := context.Background() + + result := EnsureDistOnce(ctx) + if result != "" { + t.Errorf("expected empty result, got %q", result) + } + + // Second call should also work (not locked by previous failure) + result2 := EnsureDistOnce(ctx) + if result2 != "" { + t.Errorf("expected empty result on retry, got %q", result2) + } +} diff --git a/internal/frontend/watcher.go b/internal/frontend/watcher.go new file mode 100644 index 000000000..1a0102fd1 --- /dev/null +++ b/internal/frontend/watcher.go @@ -0,0 +1,122 @@ +package frontend + +import ( + "context" + "sync" + "time" + + "github.com/OpenListTeam/OpenList/v4/pkg/utils" +) + +const defaultCheckInterval = 30 * time.Minute + +// Watcher periodically checks for new frontend versions and fetches them. +type Watcher struct { + interval time.Duration + stopCh chan struct{} + stopped bool + mu sync.Mutex + onUpdated func() +} + +// NewWatcher creates a new frontend watcher. +// onUpdated is called when a new version is fetched (used to reload static files). +func NewWatcher(onUpdated func()) *Watcher { + return &Watcher{ + interval: defaultCheckInterval, + stopCh: make(chan struct{}), + onUpdated: onUpdated, + } +} + +// SetInterval changes the check interval. Must be called before Start. +func (w *Watcher) SetInterval(d time.Duration) { + w.mu.Lock() + defer w.mu.Unlock() + w.interval = d +} + +// Start begins the periodic check loop in a background goroutine. +func (w *Watcher) Start() { + w.mu.Lock() + interval := w.interval + w.mu.Unlock() + + go func() { + utils.Log.Infof("[frontend] watcher started, checking every %s", interval) + // Check immediately on start, then periodically + w.check() + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-w.stopCh: + utils.Log.Infof("[frontend] watcher stopped") + return + case <-ticker.C: + w.check() + } + } + }() +} + +// Stop signals the watcher to stop. +func (w *Watcher) Stop() { + w.mu.Lock() + defer w.mu.Unlock() + if w.stopped { + return + } + w.stopped = true + close(w.stopCh) +} + +func (w *Watcher) check() { + if !shouldAutoFetch() { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + result, err := FetchFromRolling(ctx) + if err != nil { + utils.Log.Warnf("[frontend] watcher check failed: %v", err) + return + } + + if result.Downloaded { + utils.Log.Infof("[frontend] watcher fetched new version: %s", result.Version) + if w.onUpdated != nil { + w.onUpdated() + } + } +} + +// globalWatcher is the singleton watcher instance +var ( + globalWatcher *Watcher + watcherMu sync.Mutex +) + +// StartWatcher starts the global frontend watcher. +func StartWatcher(onUpdated func()) { + watcherMu.Lock() + defer watcherMu.Unlock() + if globalWatcher != nil { + return + } + globalWatcher = NewWatcher(onUpdated) + globalWatcher.Start() +} + +// StopWatcher stops the global frontend watcher. +func StopWatcher() { + watcherMu.Lock() + defer watcherMu.Unlock() + if globalWatcher != nil { + globalWatcher.Stop() + globalWatcher = nil + } +} diff --git a/internal/fs/copy_move.go b/internal/fs/copy_move.go index e78fc9be8..379cc6153 100644 --- a/internal/fs/copy_move.go +++ b/internal/fs/copy_move.go @@ -17,6 +17,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/OpenListTeam/tache" "github.com/pkg/errors" + log "github.com/sirupsen/logrus" ) type taskType uint8 @@ -192,21 +193,23 @@ func (t *FileTransferTask) RunWithNextTaskCallback(f func(nextTask *FileTransfer dstActualPath := stdpath.Join(t.DstActualPath, srcObj.GetName()) task_group.TransferCoordinator.AppendPayload(t.groupID, task_group.DstPathToHook(dstActualPath)) - existedObjs := make(map[string]bool) + // Pre-create the destination directory first + t.Status = "ensuring destination directory exists" + if err := op.MakeDir(t.Ctx(), t.DstStorage, dstActualPath); err != nil { + log.Warnf("[copy_move] failed to ensure destination dir [%s]: %v, will continue", dstActualPath, err) + // Continue anyway - the directory might exist but Get failed due to cache issues + } + + // Build the set of files already at dst so merge can skip them. + // Subdirectories are created on-demand: each sub-task's recursive + // RunWithNextTaskCallback calls MakeDir at line 198, and op.Put + // also calls MakeDir for the parent before uploading. op.MakeDir + // itself walks up the parent chain so deep trees are covered. + var existedObjs map[string]bool if t.TaskType == merge { - dstObjs, err := op.List(t.Ctx(), t.DstStorage, dstActualPath, model.ListArgs{}) - if err != nil && !errors.Is(err, errs.ObjectNotFound) { - // 目标文件夹不存在的情况不是错误,会在之后新建文件夹 - // 这种情况显然不需要统计existedObjs,dstObjs保持为nil,下面这个for将不会执行 - return errors.WithMessagef(err, "failed list dst [%s] objs", dstActualPath) - } - for _, obj := range dstObjs { - if err := t.Ctx().Err(); err != nil { - return err - } - if !obj.IsDir() { - existedObjs[obj.GetName()] = true - } + existedObjs, err = t.existingDstFiles(dstActualPath) + if err != nil { + return err } } @@ -263,6 +266,44 @@ func (t *FileTransferTask) RunWithNextTaskCallback(f func(nextTask *FileTransfer return op.Put(context.WithValue(t.Ctx(), conf.SkipHookKey, struct{}{}), t.DstStorage, t.DstActualPath, ss, t.SetProgress) } +// existingDstFiles wraps existingDstFilesFn with the storage-bound List call, +// for use by RunWithNextTaskCallback. Returns the set of file names already +// present at dstActualPath; a non-existent destination yields an empty map +// rather than an error so merge tasks can run against a fresh destination. +func (t *FileTransferTask) existingDstFiles(dstActualPath string) (map[string]bool, error) { + listDst := func(ctx context.Context, path string) ([]model.Obj, error) { + return op.List(ctx, t.DstStorage, path, model.ListArgs{}) + } + return existingDstFilesFn(t.Ctx(), listDst, dstActualPath) +} + +// existingDstFilesFn collects names of files (not directories) already +// present at dstPath. A non-existent destination (errs.ObjectNotFound, +// possibly wrapped) is treated as an empty result so merge tasks can run +// against fresh destinations — see PR #1898. Other List errors propagate. +// +// listDst is injected so tests can drive this without a real storage driver. +func existingDstFilesFn( + ctx context.Context, + listDst func(context.Context, string) ([]model.Obj, error), + dstPath string, +) (map[string]bool, error) { + dstObjs, err := listDst(ctx, dstPath) + if err != nil && !errors.Is(err, errs.ObjectNotFound) { + return nil, errors.WithMessagef(err, "failed list dst [%s] objs", dstPath) + } + existed := make(map[string]bool, len(dstObjs)) + for _, obj := range dstObjs { + if err := ctx.Err(); err != nil { + return nil, err + } + if !obj.IsDir() { + existed[obj.GetName()] = true + } + } + return existed, nil +} + var ( CopyTaskManager *tache.Manager[*FileTransferTask] MoveTaskManager *tache.Manager[*FileTransferTask] diff --git a/internal/fs/copy_move_test.go b/internal/fs/copy_move_test.go new file mode 100644 index 000000000..37869604b --- /dev/null +++ b/internal/fs/copy_move_test.go @@ -0,0 +1,496 @@ +package fs + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + pkgerrors "github.com/pkg/errors" +) + +// ---------- helpers ---------- + +func dirObj(name string) model.Obj { + return &model.Object{Name: name, IsFolder: true} +} + +func fileObj(name string) model.Obj { + return &model.Object{Name: name, IsFolder: false} +} + +// listRecorder records every call to listDst so tests can assert on +// invocation patterns when needed. +type listRecorder struct { + mu sync.Mutex + calls []string + respond func(path string) ([]model.Obj, error) +} + +func newListRecorder(respond func(string) ([]model.Obj, error)) *listRecorder { + return &listRecorder{respond: respond} +} + +func (r *listRecorder) listDst(_ context.Context, path string) ([]model.Obj, error) { + r.mu.Lock() + r.calls = append(r.calls, path) + r.mu.Unlock() + return r.respond(path) +} + +// ---------- tests: existingDstFilesFn ---------- +// +// These tests pin down the contract of the merge-mode existedObjs builder +// extracted from RunWithNextTaskCallback. They are the regression guard +// for the deletion of the BFS precreate logic: the only invariant that +// the BFS precreate quietly protected (and that #1898's ObjectNotFound +// tolerance independently fixes) lives in this function. + +// 1. dst exists but is empty → empty map, no error. +func TestExistingDstFilesFn_EmptyDst(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { return nil, nil }) + got, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 0 { + t.Fatalf("expected empty map, got %v", got) + } +} + +// 2. dst contains only files → all included. +func TestExistingDstFilesFn_OnlyFiles(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { + return []model.Obj{fileObj("a.txt"), fileObj("b.txt"), fileObj("c.txt")}, nil + }) + got, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, name := range []string{"a.txt", "b.txt", "c.txt"} { + if !got[name] { + t.Errorf("expected %q in existed set, got %v", name, got) + } + } + if len(got) != 3 { + t.Errorf("expected 3 entries, got %d: %v", len(got), got) + } +} + +// 3. dst contains only directories → empty map (dirs don't count as +// "existed" because merge only skips already-uploaded *files*). +func TestExistingDstFilesFn_OnlyDirs(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { + return []model.Obj{dirObj("subA"), dirObj("subB")}, nil + }) + got, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 0 { + t.Fatalf("dirs must not appear in existed-files map, got %v", got) + } +} + +// 4. dst contains mixed files and dirs → only files are recorded. +func TestExistingDstFilesFn_MixedFilesAndDirs(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { + return []model.Obj{ + fileObj("readme.md"), + dirObj("assets"), + fileObj("main.go"), + dirObj("pkg"), + fileObj("go.mod"), + }, nil + }) + got, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 3 { + t.Errorf("expected 3 file entries, got %d: %v", len(got), got) + } + for _, name := range []string{"readme.md", "main.go", "go.mod"} { + if !got[name] { + t.Errorf("expected %q in existed set, got %v", name, got) + } + } + for _, name := range []string{"assets", "pkg"} { + if got[name] { + t.Errorf("dir %q must NOT be in existed set, got %v", name, got) + } + } +} + +// 5. dst doesn't exist (raw errs.ObjectNotFound) → empty map, no error. +// THIS IS THE REGRESSION TEST FOR #1898. +func TestExistingDstFilesFn_DstDoesNotExist_RawError(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { + return nil, errs.ObjectNotFound + }) + got, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst/never-existed") + if err != nil { + t.Fatalf("ObjectNotFound on dst must be tolerated, got error: %v", err) + } + if len(got) != 0 { + t.Fatalf("expected empty map on non-existent dst, got %v", got) + } +} + +// 6. dst doesn't exist (wrapped via pkg/errors.WithMessage) → empty map. +// Guards the errors.Is unwrapping behavior that matters in practice: +// op.List wraps the underlying ObjectNotFound with context messages. +func TestExistingDstFilesFn_DstDoesNotExist_WrappedError(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { + // emulate the op.List wrapping path: GetUnwrap returns + // ObjectNotFound, list wraps with WithMessage twice. + return nil, pkgerrors.WithMessage(pkgerrors.WithMessage(errs.ObjectNotFound, "failed get dir"), "while listing") + }) + got, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + if err != nil { + t.Fatalf("wrapped ObjectNotFound must be tolerated, got error: %v", err) + } + if len(got) != 0 { + t.Fatalf("expected empty map, got %v", got) + } +} + +// 7. List returns a non-ObjectNotFound error (e.g. permission denied, +// I/O failure) → error propagates so the task fails fast. +func TestExistingDstFilesFn_OtherListError_Propagates(t *testing.T) { + sentinel := errors.New("permission denied") + rec := newListRecorder(func(string) ([]model.Obj, error) { + return nil, sentinel + }) + _, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + if err == nil { + t.Fatal("expected error to propagate, got nil") + } + if !errors.Is(err, sentinel) { + t.Fatalf("expected wrapped sentinel error, got: %v", err) + } +} + +// 8. Context cancelled before list → ctx.Err returned (the contract for +// listDst is to honor ctx; the caller's loop also checks ctx.Err). +func TestExistingDstFilesFn_CtxCancelledBeforeList(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + rec := newListRecorder(func(string) ([]model.Obj, error) { + // A real op.List would honor ctx and return ctx.Err. + return nil, ctx.Err() + }) + _, err := existingDstFilesFn(ctx, rec.listDst, "/dst") + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got: %v", err) + } +} + +// 9. Context cancelled mid-iteration → loop bails with ctx.Err and the +// partial map is not returned. +func TestExistingDstFilesFn_CtxCancelledDuringIteration(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + // Big enough list that cancelling after returning still gives the + // loop a chance to observe ctx.Err. + objs := make([]model.Obj, 1000) + for i := range objs { + objs[i] = fileObj(fmt.Sprintf("f-%d.txt", i)) + } + + rec := newListRecorder(func(string) ([]model.Obj, error) { + cancel() // cancel immediately after list returns + return objs, nil + }) + _, err := existingDstFilesFn(ctx, rec.listDst, "/dst") + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got: %v", err) + } +} + +// 10. dst contains duplicate file names (pathological but possible if a +// driver returns duplicates) → map dedupes silently, no panic, no error. +func TestExistingDstFilesFn_DuplicateNames(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { + return []model.Obj{ + fileObj("dup.txt"), + fileObj("dup.txt"), + fileObj("unique.txt"), + }, nil + }) + got, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 2 { + t.Fatalf("expected 2 dedup'd entries, got %d: %v", len(got), got) + } + if !got["dup.txt"] || !got["unique.txt"] { + t.Fatalf("expected both names, got %v", got) + } +} + +// 11. dst contains a large number of files → all included; the call +// completes within a reasonable budget (sanity, not strict perf). +func TestExistingDstFilesFn_LargeDst(t *testing.T) { + const N = 10000 + objs := make([]model.Obj, N) + for i := range objs { + objs[i] = fileObj(fmt.Sprintf("f-%05d.bin", i)) + } + rec := newListRecorder(func(string) ([]model.Obj, error) { return objs, nil }) + + start := time.Now() + got, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != N { + t.Fatalf("expected %d entries, got %d", N, len(got)) + } + // Generous bound — only catches catastrophic regressions (e.g. + // accidental O(N²) inside the loop). + if elapsed > 2*time.Second { + t.Fatalf("processing %d entries took %v, too slow", N, elapsed) + } +} + +// shouldSkipForMerge mirrors the skip decision at copy_move.go:221 so the +// test below pins both the helper (existingDstFilesFn) AND the call-site +// guard as a single contract. Any future refactor that touches either +// half of this contract will be caught. +func shouldSkipForMerge(srcObj model.Obj, existedFiles map[string]bool) bool { + return !srcObj.IsDir() && existedFiles[srcObj.GetName()] +} + +// TestMergeSkipDecision_DirsAreNeverSkipped is the explicit anti-regression +// test for the worry "when a subdir exists in dst, the src subdir is +// skipped and its contents are not merged". It walks every (src kind, dst +// state) combination and asserts the skip decision. +// +// Why this matters: if existingDstFilesFn ever started including dirs, or +// if the call-site guard dropped the `!obj.IsDir()` clause, src subdirs +// matching dst subdirs would silently stop recursing — and every file +// inside them would never get copied. That bug is invisible to a casual +// "did the task complete?" check and only shows up as quietly missing +// deep files. Lock it down. +func TestMergeSkipDecision_DirsAreNeverSkipped(t *testing.T) { + // dst already contains a mix: one file, two dirs. + dstContents := []model.Obj{ + fileObj("root_file.txt"), + dirObj("subA"), + dirObj("subB"), + } + rec := newListRecorder(func(string) ([]model.Obj, error) { return dstContents, nil }) + + existed, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Sanity: the dirs in dst must NOT be in existedObjs. If this fails, + // every src dir matching a dst dir name would be skipped. + if existed["subA"] || existed["subB"] { + t.Fatalf("CRITICAL: dirs in dst leaked into existed-files map — "+ + "matching src subdirs would be skipped and lose their contents. "+ + "got %v", existed) + } + if !existed["root_file.txt"] { + t.Fatalf("file in dst missing from existed map: %v", existed) + } + + // Exhaustive skip-decision matrix for every src-object kind against + // the populated existedObjs. + cases := []struct { + name string + srcObj model.Obj + wantSkip bool + why string + }{ + { + name: "src_file_matches_dst_file", + srcObj: fileObj("root_file.txt"), + wantSkip: true, + why: "resume semantics: already-uploaded file is skipped", + }, + { + name: "src_dir_matches_dst_dir", + srcObj: dirObj("subA"), + wantSkip: false, + why: "MUST recurse into matching subdir to merge its contents", + }, + { + name: "src_dir_matches_dst_dir_B", + srcObj: dirObj("subB"), + wantSkip: false, + why: "MUST recurse — every dst dir must trigger recursion regardless of name", + }, + { + name: "src_dir_no_match", + srcObj: dirObj("brand_new_dir"), + wantSkip: false, + why: "new dir → recurse and create", + }, + { + name: "src_file_no_match", + srcObj: fileObj("brand_new_file.txt"), + wantSkip: false, + why: "new file → upload", + }, + { + name: "src_dir_matches_dst_FILE_name", + srcObj: dirObj("root_file.txt"), + wantSkip: false, + why: "even with a name collision against a dst FILE, src dir must NOT be skipped " + + "(the conflict will surface later when MakeDir runs, not silently)", + }, + { + name: "src_file_matches_dst_DIR_name", + srcObj: fileObj("subA"), + wantSkip: false, + why: "src file colliding with a dst dir name must NOT be skipped " + + "(dirs are not in existed map, and op.Put will surface the conflict)", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := shouldSkipForMerge(c.srcObj, existed) + if got != c.wantSkip { + t.Fatalf("skip(src=%s,IsDir=%v) = %v, want %v — %s", + c.srcObj.GetName(), c.srcObj.IsDir(), got, c.wantSkip, c.why) + } + }) + } +} + +// TestMergeSkipDecision_EmptyDst: with no existedObjs, NOTHING is skipped +// regardless of obj kind — every src item gets a sub-task. +func TestMergeSkipDecision_EmptyDst(t *testing.T) { + existed := map[string]bool{} + for _, obj := range []model.Obj{ + fileObj("a.txt"), + dirObj("d1"), + fileObj("b.bin"), + dirObj("d2"), + } { + if shouldSkipForMerge(obj, existed) { + t.Errorf("nothing should be skipped against empty dst, but %s was", obj.GetName()) + } + } +} + +// TestMergeSkipDecision_DeepTreeRecursionContract simulates a 3-level src +// tree against a partial dst and asserts that the dir-recursion contract +// holds at every level. This is the "deep dir files missing" regression +// guard the user asked for. +func TestMergeSkipDecision_DeepTreeRecursionContract(t *testing.T) { + // Three levels of nesting, with files at each level. Dst already has + // the dir skeleton from a previous interrupted run, plus one file at + // the deepest level (simulating partial completion). + level0Src := []model.Obj{fileObj("root.txt"), dirObj("L1")} + level1Src := []model.Obj{fileObj("a.txt"), dirObj("L2")} + level2Src := []model.Obj{fileObj("deep1.txt"), fileObj("deep2.txt")} + + // Dst state per level + level0Dst := []model.Obj{fileObj("root.txt"), dirObj("L1")} // root.txt already uploaded + level1Dst := []model.Obj{dirObj("L2")} // L1 exists, no files yet + level2Dst := []model.Obj{fileObj("deep1.txt")} // deep1 already uploaded + + cases := []struct { + level string + srcObjs []model.Obj + dstObjs []model.Obj + mustSkip []string + mustSpawn []string + }{ + { + level: "L0", + srcObjs: level0Src, + dstObjs: level0Dst, + mustSkip: []string{"root.txt"}, // file already there + mustSpawn: []string{"L1"}, // dir must recurse + }, + { + level: "L1", + srcObjs: level1Src, + dstObjs: level1Dst, + mustSkip: []string{}, // a.txt not in dst + mustSpawn: []string{"a.txt", "L2"}, // upload file + recurse dir + }, + { + level: "L2", + srcObjs: level2Src, + dstObjs: level2Dst, + mustSkip: []string{"deep1.txt"}, // already uploaded + mustSpawn: []string{"deep2.txt"}, // must still upload + }, + } + + for _, c := range cases { + t.Run(c.level, func(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { return c.dstObjs, nil }) + existed, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst/"+c.level) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spawned, skipped []string + for _, obj := range c.srcObjs { + if shouldSkipForMerge(obj, existed) { + skipped = append(skipped, obj.GetName()) + } else { + spawned = append(spawned, obj.GetName()) + } + } + + if !sameStrings(skipped, c.mustSkip) { + t.Errorf("%s: skipped = %v, want %v", c.level, skipped, c.mustSkip) + } + if !sameStrings(spawned, c.mustSpawn) { + t.Errorf("%s: spawned = %v, want %v", c.level, spawned, c.mustSpawn) + } + }) + } +} + +func sameStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + m := map[string]int{} + for _, s := range a { + m[s]++ + } + for _, s := range b { + m[s]-- + } + for _, v := range m { + if v != 0 { + return false + } + } + return true +} + +// 12. Sanity: the helper invokes listDst exactly once with the given +// dstPath (no double-listing). +func TestExistingDstFilesFn_CallsListOnce(t *testing.T) { + rec := newListRecorder(func(string) ([]model.Obj, error) { + return []model.Obj{fileObj("a.txt")}, nil + }) + _, err := existingDstFilesFn(context.Background(), rec.listDst, "/dst/exact/path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(rec.calls) != 1 || rec.calls[0] != "/dst/exact/path" { + t.Fatalf("expected single call to /dst/exact/path, got %v", rec.calls) + } +} diff --git a/internal/hybrid_cache/buffer.go b/internal/hybrid_cache/buffer.go new file mode 100644 index 000000000..0023996e6 --- /dev/null +++ b/internal/hybrid_cache/buffer.go @@ -0,0 +1,96 @@ +package hybrid_cache + +import ( + "fmt" + "io" +) + +type BufferStore struct { + blocks [][]byte + size int64 +} + +func (m *BufferStore) Size() int64 { + return m.size +} + +// 用于存储不复用的[]byte +func (m *BufferStore) Append(buf []byte) { + m.size += int64(len(buf)) + m.blocks = append(m.blocks, buf) +} + +func (m *BufferStore) Close() error { + if len(m.blocks) > 0 { + clear(m.blocks) + m.blocks = m.blocks[:0] + m.size = 0 + } + return nil +} + +func (m *BufferStore) ReadAt(p []byte, off int64) (int, error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 || off >= m.size { + return 0, io.EOF + } + + var n int + for _, buf := range m.blocks { + if off >= int64(len(buf)) { + off -= int64(len(buf)) + continue + } + nn := copy(p[n:], buf[off:]) + n += nn + if n == len(p) { + return n, nil + } + off = 0 + } + + return n, io.EOF +} + +func (m *BufferStore) WriteAt(p []byte, off int64) (int, error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 || off >= m.size { + return 0, io.ErrShortWrite + } + + var n int + for _, b := range m.blocks { + if off >= int64(len(b)) { + off -= int64(len(b)) + continue + } + nn := copy(b[off:], p[n:]) + n += nn + if n == len(p) { + return n, nil + } + off = 0 + } + + return n, io.ErrShortWrite +} + +func (m *BufferStore) GrowTo(size int64) (err error) { + if size <= m.size { + return nil + } + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("recovered in %v", r) + } + }() + m.blocks = append(m.blocks, make([]byte, size-m.size)) + m.size = size + return nil +} + +var _ BackingStore = (*BufferStore)(nil) diff --git a/pkg/buffer/bytes_test.go b/internal/hybrid_cache/buffer_test.go similarity index 67% rename from pkg/buffer/bytes_test.go rename to internal/hybrid_cache/buffer_test.go index 3f4d85563..439d7cb78 100644 --- a/pkg/buffer/bytes_test.go +++ b/internal/hybrid_cache/buffer_test.go @@ -1,26 +1,32 @@ -package buffer +package hybrid_cache_test import ( "errors" "io" "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/hybrid_cache" ) -func TestReader_ReadAt(t *testing.T) { +func TestBufferStore(t *testing.T) { type args struct { p []byte off int64 } - bs := &Reader{} + bs := &hybrid_cache.BufferStore{} bs.Append([]byte("github.com")) bs.Append([]byte("/OpenList")) - bs.Append([]byte("Team/")) - bs.Append([]byte("OpenList")) + bs.Append([]byte("Team/?")) + b := []byte("OpenList") + off := bs.Size() - 1 + _ = bs.GrowTo(off + int64(len(b))) + _, _ = bs.WriteAt(b, off) + tests := []struct { - name string - b *Reader - args args - want func(a args, n int, err error) error + name string + b *hybrid_cache.BufferStore + args args + check func(a args, n int, err error) error }{ { name: "readAt len 10 offset 0", @@ -29,7 +35,7 @@ func TestReader_ReadAt(t *testing.T) { p: make([]byte, 10), off: 0, }, - want: func(a args, n int, err error) error { + check: func(a args, n int, err error) error { if n != len(a.p) { return errors.New("read length not match") } @@ -49,7 +55,7 @@ func TestReader_ReadAt(t *testing.T) { p: make([]byte, 12), off: 11, }, - want: func(a args, n int, err error) error { + check: func(a args, n int, err error) error { if n != len(a.p) { return errors.New("read length not match") } @@ -69,7 +75,7 @@ func TestReader_ReadAt(t *testing.T) { p: make([]byte, 50), off: 24, }, - want: func(a args, n int, err error) error { + check: func(a args, n int, err error) error { if n != int(bs.Size()-a.off) { return errors.New("read length not match") } @@ -86,8 +92,8 @@ func TestReader_ReadAt(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.b.ReadAt(tt.args.p, tt.args.off) - if err := tt.want(tt.args, got, err); err != nil { - t.Errorf("Bytes.ReadAt() error = %v", err) + if err := tt.check(tt.args, got, err); err != nil { + t.Errorf("BufferStore.ReadAt() error = %v", err) } }) } diff --git a/internal/hybrid_cache/file.go b/internal/hybrid_cache/file.go new file mode 100644 index 000000000..bca197dad --- /dev/null +++ b/internal/hybrid_cache/file.go @@ -0,0 +1,179 @@ +package hybrid_cache + +import ( + "errors" + "io" + "os" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" +) + +type singleFileStore struct { + *os.File + size int64 +} + +func (s *singleFileStore) Size() int64 { + return s.size +} + +func (s *singleFileStore) GrowTo(size int64) error { + if size <= s.size { + return nil + } + err := s.File.Truncate(size) + if err == nil { + s.size = size + } + return err +} + +func (s *singleFileStore) Close() error { + err := s.File.Close() + _ = os.Remove(s.File.Name()) + return err +} + +type fileBlock struct { + file *os.File + size int64 + written int64 +} + +type MultiFileStore struct { + blocks []*fileBlock + size int64 +} + +func (s *MultiFileStore) Size() int64 { + return s.size +} + +func (m *MultiFileStore) Close() error { + var errs []error + for _, c := range m.blocks { + if err := c.file.Close(); err != nil { + errs = append(errs, err) + } + _ = os.Remove(c.file.Name()) + } + clear(m.blocks) + m.blocks = m.blocks[:0] + return errors.Join(errs...) +} + +func (m *MultiFileStore) GrowTo(size int64) error { + if size <= m.size { + return nil + } + f, err := os.CreateTemp(conf.Conf.TempDir, "file-*") + if err != nil { + return err + } + m.blocks = append(m.blocks, &fileBlock{file: f, size: size - m.size}) + m.size = size + return nil +} + +func (m *MultiFileStore) ReadAt(p []byte, off int64) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 || off >= m.size { + return 0, io.EOF + } + + for _, c := range m.blocks { + if off >= c.size { + off -= c.size + continue + } + + canRead := min(len(p)-n, int(c.size-off)) + if canRead <= 0 { + break + } + + filled := 0 + + if off < c.written { + fileReadable := min(canRead, int(c.written-off)) + nn, fileErr := c.file.ReadAt(p[n:n+fileReadable], off) + n += nn + filled = nn + if fileErr != nil && !errors.Is(fileErr, io.EOF) { + return n, fileErr + } + } + + if n == len(p) { + return n, nil + } + + if zeroFill := canRead - filled; zeroFill > 0 { + clear(p[n : n+zeroFill]) + n += zeroFill + } + + if n == len(p) { + return n, nil + } + off = 0 + } + + return n, io.EOF +} + +func (m *MultiFileStore) WriteAt(p []byte, off int64) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 || off >= m.size { + return 0, io.ErrShortWrite + } + + for _, b := range m.blocks { + if off >= b.size { + off -= b.size + continue + } + + canWrite := min(len(p)-n, int(b.size-off)) + if canWrite <= 0 { + break + } + + nn, fileErr := b.file.WriteAt(p[n:n+canWrite], off) + if end := off + int64(nn); end > b.written { + b.written = end + } + n += nn + if fileErr != nil { + return n, fileErr + } + if nn < canWrite { + return n, io.ErrShortWrite + } + if n == len(p) { + return n, nil + } + off = 0 + } + + return n, io.ErrShortWrite +} + +func NewFileStore(blockSize int64) (BackingStore, error) { + f, err := os.CreateTemp(conf.Conf.TempDir, "file-*") + if err != nil { + return nil, err + } + err = f.Truncate(blockSize) + if err == nil { + return &singleFileStore{File: f, size: blockSize}, nil + } + return &MultiFileStore{ + blocks: []*fileBlock{{file: f, size: blockSize}}, + size: blockSize, + }, nil +} diff --git a/internal/hybrid_cache/file_test.go b/internal/hybrid_cache/file_test.go new file mode 100644 index 000000000..7f83625db --- /dev/null +++ b/internal/hybrid_cache/file_test.go @@ -0,0 +1,101 @@ +package hybrid_cache_test + +import ( + "bytes" + "errors" + "io" + "os" + "reflect" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/hybrid_cache" +) + +func TestFile(t *testing.T) { + f, err := os.CreateTemp("", "writeat-*") + if err != nil { + t.Error(err) + return + } + defer os.Remove(f.Name()) + defer f.Close() + t.Run("ReadAt", func(t *testing.T) { + _, err := f.ReadAt(make([]byte, 1), 20) + if err != nil && !errors.Is(err, io.EOF) { + t.Error(err) + } + }) + t.Run("WriteAt", func(t *testing.T) { + n, err := f.WriteAt([]byte("abc"), 20) + if err != nil { + t.Errorf("write n=%d err=%v", n, err) + return + } + stat, err := f.Stat() + if err != nil { + t.Errorf("stat err=%v", err) + return + } + if stat.Size() != 23 { + t.Fatalf("unexpected size: got %d want 23", stat.Size()) + } + + b := make([]byte, stat.Size()) + rn, rerr := f.ReadAt(b, 0) + if rn != len(b) || rerr != nil { + t.Fatalf("read n=%d err=%v", rn, rerr) + } + want := append(make([]byte, 20), []byte("abc")...) + if !reflect.DeepEqual(b, want) { + t.Fatalf("unexpected content: got %v want %v", b, want) + } + }) +} + +func TestMultiFileCache(t *testing.T) { + prevConf := conf.Conf + t.Cleanup(func() { + conf.Conf = prevConf + }) + conf.Conf = &conf.Config{} + f := hybrid_cache.MultiFileStore{} + defer f.Close() + t.Run("ReadAt", func(t *testing.T) { + _, err := f.ReadAt(make([]byte, 1), 20) + if err != nil && !errors.Is(err, io.EOF) { + t.Error(err) + } + }) + t.Run("WriteAt", func(t *testing.T) { + err := f.GrowTo(15) + if err != nil { + t.Errorf("truncate err=%v", err) + return + } + n, err := f.WriteAt([]byte("abc"), 10) + if err != nil { + t.Errorf("write n=%d err=%v", n, err) + return + } + + err = f.GrowTo(30) + if err != nil { + t.Errorf("truncate err=%v", err) + return + } + _, _ = f.WriteAt([]byte("123"), 15) + + b := append(make([]byte, 17), []byte("def")...) + b[0] = 'a' + rn, rerr := f.ReadAt(b, 8) + if rn != len(b) || rerr != nil { + t.Fatalf("read n=%d err=%v", rn, rerr) + } + want := []byte{0, 0, 'a', 'b', 'c', 0, 0, '1', '2', '3'} + want = append(want, make([]byte, 10)...) + if !bytes.Equal(b, want) { + t.Fatalf("unexpected content: got %v want %v", b, want) + } + }) +} diff --git a/internal/hybrid_cache/hybrid_cache.go b/internal/hybrid_cache/hybrid_cache.go new file mode 100644 index 000000000..d80f768f1 --- /dev/null +++ b/internal/hybrid_cache/hybrid_cache.go @@ -0,0 +1,240 @@ +package hybrid_cache + +import ( + "errors" + "io" + "runtime" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/mem" + "github.com/OpenListTeam/OpenList/v4/pkg/buffer" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" +) + +// 线程不安全,单线程使用,或者外部加锁保护 +type HybridCache struct { + blockSize uint64 + memoryStore mem.LinearMemory + memoryOffset uint64 + backingStore BackingStore + backingOffset uint64 +} + +// HybridCache本身是一个大的Block,支持分块成多个小的Block + +// 分配一个新的Block,支持读写,大小为size +func (hc *HybridCache) AllocBlock(size uint64) (buffer.Block, error) { +retry: + if hc.backingStore != nil { + if err := hc.backingStore.GrowTo(int64(hc.backingOffset + size)); err != nil { + return nil, err + } + base := hc.backingOffset + hc.backingOffset += size + fs := buffer.NewBlockAdapter( + io.NewOffsetWriter(hc.backingStore, int64(base)), + io.NewSectionReader(hc.backingStore, int64(base), int64(size)), + ) + return fs, nil + } + all, err := hc.memoryStore.Reallocate(hc.memoryOffset + size) + if err == nil { + start := hc.memoryOffset + hc.memoryOffset += size + return buffer.NewByteBlock(all[start : start+size]), nil + } + if err2 := hc.initFileCache(); err2 != nil { + return nil, errors.Join(err, err2) + } + goto retry +} + +func (hc *HybridCache) allocWriteAtSeeker(size uint64) (buffer.WriteAtSeeker, error) { +retry: + if hc.backingStore != nil { + if err := hc.backingStore.GrowTo(int64(hc.backingOffset + size)); err != nil { + return nil, err + } + base := hc.backingOffset + hc.backingOffset += size + return io.NewOffsetWriter(hc.backingStore, int64(base)), nil + } + all, err := hc.memoryStore.Reallocate(hc.memoryOffset + size) + if err == nil { + start := hc.memoryOffset + hc.memoryOffset += size + return io.NewOffsetWriter(buffer.NewByteBlock(all[start:start+size]), 0), nil + } + if err2 := hc.initFileCache(); err2 != nil { + return nil, errors.Join(err, err2) + } + goto retry +} + +func (hc *HybridCache) NextBlock() (buffer.Block, error) { + return hc.AllocBlock(hc.blockSize) +} + +func (hc *HybridCache) RewindBySize(size uint64) { + if hc.backingOffset >= size { + hc.backingOffset -= size + return + } + size -= hc.backingOffset + hc.backingOffset = 0 + if hc.memoryOffset >= size { + hc.memoryOffset -= size + return + } + size -= hc.memoryOffset + hc.memoryOffset = 0 +} + +func (hc *HybridCache) RewindOneBlock() { + hc.RewindBySize(hc.blockSize) +} + +func (hc *HybridCache) initFileCache() error { + file, err := NewFileStore(int64(hc.blockSize)) + if err != nil { + return err + } + hc.backingStore = file + return nil +} + +func (hc *HybridCache) Close() error { + var err error + if hc.memoryStore != nil { + err = hc.memoryStore.Free() + hc.memoryStore = nil + hc.memoryOffset = 0 + } + if hc.backingStore != nil { + err = errors.Join(err, hc.backingStore.Close()) + hc.backingStore = nil + hc.backingOffset = 0 + } + return err +} + +func (hc *HybridCache) Size() int64 { + return int64(hc.memoryOffset + hc.backingOffset) +} + +func (hc *HybridCache) ReadAt(p []byte, off int64) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 || off >= hc.Size() { + return 0, io.EOF + } + + if off < int64(hc.memoryOffset) { + all, err := hc.memoryStore.Reallocate(min(hc.memoryOffset, uint64(off)+uint64(len(p)))) + if err != nil { + // 不可能失败 + panic(err) + } + n = copy(p, all[off:]) + if n == len(p) { + return n, nil + } + p = p[n:] + } + + off += int64(n) - int64(hc.memoryOffset) + canRead := int64(hc.backingOffset) - off + if canRead <= 0 { + return n, io.EOF + } + nn, err := hc.backingStore.ReadAt(p[:min(len(p), int(canRead))], off) + return n + nn, err +} + +func (hc *HybridCache) WriteAt(p []byte, off int64) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 || off >= hc.Size() { + return 0, io.ErrShortWrite + } + + if off < int64(hc.memoryOffset) { + all, err := hc.memoryStore.Reallocate(min(hc.memoryOffset, uint64(off)+uint64(len(p)))) + if err != nil { + // 不可能失败 + panic(err) + } + n = copy(all[off:], p) + if n == len(p) { + return n, nil + } + p = p[n:] + } + + off += int64(n) - int64(hc.memoryOffset) + canWrite := int64(hc.backingOffset) - off + if canWrite <= 0 { + return n, io.ErrShortWrite + } + nn, err := hc.backingStore.WriteAt(p[:min(len(p), int(canWrite))], off) + return n + nn, err +} + +func (hc *HybridCache) CopyFromN(src io.Reader, n int64) (written int64, err error) { + limit := n + for limit > 0 { + blockSize := limit + if hc.backingStore == nil && blockSize > int64(conf.MaxBlockLimit) { + blockSize = int64(conf.MaxBlockLimit) + } + b, err := hc.allocWriteAtSeeker(uint64(blockSize)) + if err != nil { + return written, err + } + nn, err := utils.CopyWithBufferN(b, src, blockSize) + written += nn + if nn != blockSize { + return written, err + } + limit -= nn + } + return written, nil +} + +// HybridCache 线程不安全,单线程使用,或者外部加锁保护 +func NewHybridCache(blockSize, maxMemorySize uint64) (hc *HybridCache, err error) { + if conf.MinFreeMemory > 0 { + // 策略1: Go自动内存管理 + if maxMemorySize <= conf.AutoMemoryLimit { + return &HybridCache{backingStore: &BufferStore{}, blockSize: blockSize}, nil + } + + // 策略2: 手动内存管理 + if maxMemorySize >= blockSize { + var m mem.LinearMemory + // 手动管理内存,Uinx Mmap 或者 Windows VirtualAlloc + if m, err = mem.NewGuardedMemory(blockSize, maxMemorySize); err == nil { + hc = &HybridCache{memoryStore: m, blockSize: blockSize} + } + } + } + // 策略3: 文件后备 + if hc == nil { + hc = &HybridCache{blockSize: blockSize} + // 文件 + if err2 := hc.initFileCache(); err2 != nil { + return nil, errors.Join(err, err2) + } + } + runtime.SetFinalizer(hc, func(hc *HybridCache) { + if hc.backingStore != nil { + _ = hc.backingStore.Close() + hc.backingStore = nil + } + }) + return hc, nil +} + +var _ buffer.Block = (*HybridCache)(nil) diff --git a/internal/hybrid_cache/type.go b/internal/hybrid_cache/type.go new file mode 100644 index 000000000..edc6d3deb --- /dev/null +++ b/internal/hybrid_cache/type.go @@ -0,0 +1,13 @@ +package hybrid_cache + +import ( + "io" + + "github.com/OpenListTeam/OpenList/v4/pkg/buffer" +) + +type BackingStore interface { + buffer.Block + io.Closer + GrowTo(size int64) error +} diff --git a/internal/mem/mem_other.go b/internal/mem/mem_other.go new file mode 100644 index 000000000..2535c4570 --- /dev/null +++ b/internal/mem/mem_other.go @@ -0,0 +1,25 @@ +//go:build !unix && !windows + +package mem + +func NewMemory(cap, max uint64) (LinearMemory, error) { + return &sliceMemory{buf: make([]byte, 0, cap)}, nil +} + +type sliceMemory struct { + buf []byte +} + +func (b *sliceMemory) Free() error { + b.buf = nil + return nil +} + +func (b *sliceMemory) Reallocate(size uint64) ([]byte, error) { + if cap := uint64(cap(b.buf)); size > cap { + b.buf = append(b.buf[:cap], make([]byte, size-cap)...) + } else { + b.buf = b.buf[:size] + } + return b.buf, nil +} diff --git a/internal/mem/mem_unix.go b/internal/mem/mem_unix.go new file mode 100644 index 000000000..bd1979298 --- /dev/null +++ b/internal/mem/mem_unix.go @@ -0,0 +1,92 @@ +//go:build unix + +package mem + +import ( + "math" + + "golang.org/x/sys/unix" +) + +func NewMemory(cap, max uint64) (LinearMemory, error) { + // Round up to the page size. + rnd := uint64(unix.Getpagesize() - 1) + res := (max + rnd) &^ rnd + + if res > math.MaxInt { + // This ensures int(res) overflows to a negative value, + // and unix.Mmap returns EINVAL. + res = math.MaxUint64 + } + + com := res + prot := unix.PROT_READ | unix.PROT_WRITE + if cap < max { // Commit memory only if cap=max. + com = 0 + prot = unix.PROT_NONE + } + + // Reserve res bytes of address space, to ensure we won't need to move it. + // A protected, private, anonymous mapping should not commit memory. + b, err := unix.Mmap(-1, 0, int(res), prot, unix.MAP_PRIVATE|unix.MAP_ANON) + if err != nil { + return nil, err + } + return &mmappedMemory{buf: b[:com]}, nil +} + +// The slice covers the entire mmapped memory: +// - len(buf) is the already committed memory, +// - cap(buf) is the reserved address space. +type mmappedMemory struct { + buf []byte + growCheck GrowCheck +} + +func (m *mmappedMemory) SetGrowCheck(c GrowCheck) { + m.growCheck = c +} + +func (m *mmappedMemory) Reallocate(size uint64) ([]byte, error) { + com := uint64(len(m.buf)) + res := uint64(cap(m.buf)) + if com < size { + if size <= res { + // Grow geometrically, round up to the page size. + rnd := uint64(unix.Getpagesize() - 1) + new := com + com>>3 + new = min(max(size, new), res) + new = (new + rnd) &^ rnd + + if m.growCheck != nil { + if err := m.growCheck(new - com); err != nil { + return nil, err + } + } + + // Commit additional memory up to new bytes. + err := unix.Mprotect(m.buf[com:new], unix.PROT_READ|unix.PROT_WRITE) + if err != nil { + return nil, err + } + + m.buf = m.buf[:new] // Update committed memory. + } else { + return nil, ErrNotEnoughMemory + } + } + // Limit returned capacity because bytes beyond + // len(m.buf) have not yet been committed. + return m.buf[:size:len(m.buf)], nil +} + +func (m *mmappedMemory) Free() error { + if m.buf != nil { + err := unix.Munmap(m.buf[:cap(m.buf)]) + if err != nil { + return err + } + m.buf = nil + } + return nil +} diff --git a/internal/mem/mem_windows.go b/internal/mem/mem_windows.go new file mode 100644 index 000000000..e7a4bb27e --- /dev/null +++ b/internal/mem/mem_windows.go @@ -0,0 +1,94 @@ +package mem + +import ( + "math" + "unsafe" + + "golang.org/x/sys/windows" +) + +func NewMemory(cap, max uint64) (LinearMemory, error) { + // Round up to the page size. + rnd := uint64(windows.Getpagesize() - 1) + res := (max + rnd) &^ rnd + + if res > math.MaxInt { + // This ensures uintptr(res) overflows to a large value, + // and windows.VirtualAlloc returns an error. + res = math.MaxUint64 + } + + com := res + kind := windows.MEM_COMMIT + if cap < max { // Commit memory only if cap=max. + com = 0 + kind = windows.MEM_RESERVE + } + + // Reserve res bytes of address space, to ensure we won't need to move it. + r, err := windows.VirtualAlloc(0, uintptr(res), uint32(kind), windows.PAGE_READWRITE) + if err != nil { + return nil, err + } + + buf := unsafe.Slice((*byte)(unsafe.Pointer(r)), int(res)) + return &virtualMemory{addr: r, buf: buf[:com]}, nil +} + +// The slice covers the entire mmapped memory: +// - len(buf) is the already committed memory, +// - cap(buf) is the reserved address space. +type virtualMemory struct { + buf []byte + addr uintptr + growCheck GrowCheck +} + +func (m *virtualMemory) SetGrowCheck(c GrowCheck) { + m.growCheck = c +} + +func (m *virtualMemory) Reallocate(size uint64) ([]byte, error) { + com := uint64(len(m.buf)) + res := uint64(cap(m.buf)) + if com < size { + if size <= res { + // Grow geometrically, round up to the page size. + rnd := uint64(windows.Getpagesize() - 1) + new := com + com>>3 + new = min(max(size, new), res) + new = (new + rnd) &^ rnd + + if m.growCheck != nil { + if err := m.growCheck(new - com); err != nil { + return nil, err + } + } + + // Commit additional memory up to new bytes. + _, err := windows.VirtualAlloc(m.addr, uintptr(new), windows.MEM_COMMIT, windows.PAGE_READWRITE) + if err != nil { + return nil, err + } + + m.buf = m.buf[:new] // Update committed memory. + } else { + return nil, ErrNotEnoughMemory + } + } + // Limit returned capacity because bytes beyond + // len(m.buf) have not yet been committed. + return m.buf[:size:len(m.buf)], nil +} + +func (m *virtualMemory) Free() error { + if m.addr != 0 { + err := windows.VirtualFree(m.addr, 0, windows.MEM_RELEASE) + if err != nil { + return err + } + m.addr = 0 + m.buf = nil + } + return nil +} diff --git a/internal/mem/type.go b/internal/mem/type.go new file mode 100644 index 000000000..3ba8e35ca --- /dev/null +++ b/internal/mem/type.go @@ -0,0 +1,9 @@ +package mem + +type LinearMemory interface { + // 线程不安全 + Reallocate(size uint64) (all []byte, err error) + Free() error +} + +type GrowCheck func(growSize uint64) error diff --git a/internal/mem/utils.go b/internal/mem/utils.go new file mode 100644 index 000000000..21944066d --- /dev/null +++ b/internal/mem/utils.go @@ -0,0 +1,79 @@ +package mem + +import ( + "errors" + "fmt" + "runtime" + "sync/atomic" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/pkg/singleflight" + "github.com/shirou/gopsutil/v4/mem" +) + +var ErrNotEnoughMemory = errors.New("not enough memory") + +func MemoryGrowCheck(growSize uint64) error { + if conf.MinFreeMemory == 0 { + return ErrNotEnoughMemory + } + m, err, _ := singleflight.AnyGroup.Do("MemoryGrowCheck", func() (any, error) { + m, err := mem.VirtualMemory() + if err != nil { + return nil, err + } + if m.Available < conf.MinFreeMemory { + return nil, ErrNotEnoughMemory + } + return m, nil + }) + if err != nil { + return err + } + memStat := m.(*mem.VirtualMemoryStat) + for { + available := atomic.LoadUint64(&memStat.Available) + if available < growSize || available-growSize < conf.MinFreeMemory { + return ErrNotEnoughMemory + } + if atomic.CompareAndSwapUint64(&memStat.Available, available, available-growSize) { + return nil + } + } +} + +func NewGuardedMemory(cap, max uint64) (m LinearMemory, err error) { + if err := MemoryGrowCheck(cap); err != nil { + return nil, err + } + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%w: %v", ErrNotEnoughMemory, r) + } + }() + m, err = NewMemory(cap, max) + if err != nil { + return nil, err + } + if s, ok := m.(interface{ SetGrowCheck(GrowCheck) }); ok { + s.SetGrowCheck(MemoryGrowCheck) + } + gm := &guardedMemory{m} + runtime.SetFinalizer(gm, func(gm *guardedMemory) { + gm.Free() + }) + return gm, nil +} + +type guardedMemory struct { + LinearMemory +} + +func (s *guardedMemory) Reallocate(size uint64) (all []byte, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%w: %v", ErrNotEnoughMemory, r) + } + }() + return s.LinearMemory.Reallocate(size) +} diff --git a/internal/model/args.go b/internal/model/args.go index 073c94a63..d165908fb 100644 --- a/internal/model/args.go +++ b/internal/model/args.go @@ -25,6 +25,10 @@ type LinkArgs struct { Redirect bool } +// LinkRefresher is a callback function type for refreshing download links +// It returns a new Link and the associated object, or an error +type LinkRefresher func(ctx context.Context) (*Link, Obj, error) + type Link struct { URL string `json:"url"` // most common way Header http.Header `json:"header"` // needed header (for url) @@ -37,6 +41,10 @@ type Link struct { PartSize int `json:"part_size"` ContentLength int64 `json:"content_length"` // 转码视频、缩略图 + // Refresher is a callback to refresh the link when it expires during long downloads + // This field is not serialized and is optional - if nil, no refresh will be attempted + Refresher LinkRefresher `json:"-"` + utils.SyncClosers `json:"-"` // 如果SyncClosers中的资源被关闭后Link将不可用,则此值应为 true RequireReference bool `json:"-"` diff --git a/internal/model/file.go b/internal/model/file.go index 4ca7201e1..d6697cd0e 100644 --- a/internal/model/file.go +++ b/internal/model/file.go @@ -7,24 +7,26 @@ import ( // File is basic file level accessing interface type File interface { - io.Reader io.ReaderAt - io.Seeker + io.ReadSeeker +} +type FileWriter interface { + io.WriterAt + io.WriteSeeker } type FileCloser struct { File io.Closer } -func (f *FileCloser) Close() error { - var errs []error +func (f *FileCloser) Close() (err error) { if clr, ok := f.File.(io.Closer); ok { - errs = append(errs, clr.Close()) + err = clr.Close() } if f.Closer != nil { - errs = append(errs, f.Closer.Close()) + return errors.Join(err, f.Closer.Close()) } - return errors.Join(errs...) + return } // FileRangeReader 是对 RangeReaderIF 的轻量包装,表明由 RangeReaderIF.RangeRead diff --git a/internal/net/oss.go b/internal/net/oss.go index a897161f1..b28e06bc4 100644 --- a/internal/net/oss.go +++ b/internal/net/oss.go @@ -1,9 +1,36 @@ package net -import "github.com/aliyun/aliyun-oss-go-sdk/oss" +import ( + "crypto/tls" + stdnet "net" + "net/http" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/aliyun/aliyun-oss-go-sdk/oss" +) func NewOSSClient(endpoint, accessKeyID, accessKeySecret string, options ...oss.ClientOption) (*oss.Client, error) { - clientOptions := []oss.ClientOption{oss.HTTPClient(NewHttpClient())} + clientOptions := []oss.ClientOption{oss.HTTPClient(NewOSSUploadHttpClient())} clientOptions = append(clientOptions, options...) return oss.New(endpoint, accessKeyID, accessKeySecret, clientOptions...) } + +func NewOSSUploadHttpClient() *http.Client { + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}, + DialContext: (&stdnet.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ResponseHeaderTimeout: 5 * time.Minute, + } + + SetProxyIfConfigured(transport) + + return &http.Client{ + Timeout: time.Hour * 48, + Transport: transport, + } +} diff --git a/internal/net/oss_test.go b/internal/net/oss_test.go index 9001cd39d..13a4f86fe 100644 --- a/internal/net/oss_test.go +++ b/internal/net/oss_test.go @@ -8,6 +8,21 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/conf" ) +func TestNewOSSUploadHttpClientHasLongerTimeout(t *testing.T) { + oldConf := conf.Conf + conf.Conf = conf.DefaultConfig("data") + defer func() { conf.Conf = oldConf }() + + client := NewOSSUploadHttpClient() + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", client.Transport) + } + if transport.ResponseHeaderTimeout < 120_000_000_000 { // 2 minutes minimum + t.Fatalf("ResponseHeaderTimeout=%v, want >= 2m for upload", transport.ResponseHeaderTimeout) + } +} + func TestNewOSSClientUsesEnvironmentHTTPSProxy(t *testing.T) { oldConf := conf.Conf conf.Conf = conf.DefaultConfig("data") diff --git a/internal/net/request.go b/internal/net/request.go index e1f045120..4fcdad0fd 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -5,17 +5,20 @@ import ( "errors" "fmt" "io" + "math/rand/v2" "net/http" stdpath "path" "strconv" "sync" + "sync/atomic" "time" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" + hcache "github.com/OpenListTeam/OpenList/v4/internal/hybrid_cache" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/buffer" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/rclone/rclone/lib/mmap" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/aws/aws-sdk-go/aws/awsutil" @@ -86,8 +89,8 @@ func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readClo if impl.cfg.PartSize == 0 { impl.cfg.PartSize = DefaultDownloadPartSize } - if conf.MaxBufferLimit > 0 && impl.cfg.PartSize > conf.MaxBufferLimit { - impl.cfg.PartSize = conf.MaxBufferLimit + if conf.MinFreeMemory > 0 && impl.cfg.PartSize > int(conf.MaxBlockLimit) { + impl.cfg.PartSize = int(conf.MaxBlockLimit) } if impl.cfg.HttpClient == nil { impl.cfg.HttpClient = DefaultHttpRequestFunc @@ -102,65 +105,57 @@ type downloader struct { cancel context.CancelCauseFunc cfg Downloader - params *HttpRequestParams //http request params - chunkChannel chan chunk //chunk chanel + params *HttpRequestParams //http request params + chunkCh chan chunk //chunk chanel //wg sync.WaitGroup - m sync.Mutex + mu sync.Mutex nextChunk int //next chunk id - bufs []*Buf + bufMap map[int]*buffer.PipeBuffer written int64 //total bytes of file downloaded from remote - err error concurrency int //剩余的并发数,递减。到0时停止并发 - maxPart int //有多少个分片 pos int64 maxPos int64 - m2 sync.Mutex - readingID int // 正在被读取的id + delayMu sync.Mutex + readingID int64 // 正在被读取的id + + hc *hcache.HybridCache } type ConcurrencyLimit struct { - _m sync.Mutex - Limit int // 需要大于0 + mu sync.Mutex + + Limit uint32 } var ErrExceedMaxConcurrency = HttpStatusCodeError(http.StatusTooManyRequests) -func (l *ConcurrencyLimit) sub() error { - l._m.Lock() - defer l._m.Unlock() - if l.Limit-1 < 0 { +func (l *ConcurrencyLimit) Acquire() error { + if l == nil { + return nil + } + l.mu.Lock() + defer l.mu.Unlock() + if l.Limit == 0 { return ErrExceedMaxConcurrency } l.Limit-- - // log.Debugf("ConcurrencyLimit.sub: %d", l.Limit) return nil } -func (l *ConcurrencyLimit) add() { - l._m.Lock() - defer l._m.Unlock() - l.Limit++ - // log.Debugf("ConcurrencyLimit.add: %d", l.Limit) -} - -// 检测是否超过限制 -func (d *downloader) concurrencyCheck() error { - if d.cfg.ConcurrencyLimit != nil { - return d.cfg.ConcurrencyLimit.sub() - } - return nil -} -func (d *downloader) concurrencyFinish() { - if d.cfg.ConcurrencyLimit != nil { - d.cfg.ConcurrencyLimit.add() +func (l *ConcurrencyLimit) Release() { + if l == nil { + return } + l.mu.Lock() + l.Limit++ + l.mu.Unlock() } // download performs the implementation of the object download across ranged GETs. func (d *downloader) download() (io.ReadCloser, error) { - if err := d.concurrencyCheck(); err != nil { + if err := d.cfg.ConcurrencyLimit.Acquire(); err != nil { return nil, err } @@ -176,16 +171,16 @@ func (d *downloader) download() (io.ReadCloser, error) { if maxPart == 1 { resp, err := d.cfg.HttpClient(d.ctx, d.params) if err != nil { - d.concurrencyFinish() + d.cfg.ConcurrencyLimit.Release() return nil, err } closeFunc := resp.Body.Close resp.Body = utils.NewReadCloser(resp.Body, func() error { - d.m.Lock() - defer d.m.Unlock() + d.mu.Lock() + defer d.mu.Unlock() var err error if closeFunc != nil { - d.concurrencyFinish() + d.cfg.ConcurrencyLimit.Release() err = closeFunc() closeFunc = nil } @@ -196,103 +191,131 @@ func (d *downloader) download() (io.ReadCloser, error) { d.ctx, d.cancel = context.WithCancelCause(d.ctx) // workers - d.chunkChannel = make(chan chunk, d.cfg.Concurrency) + d.chunkCh = make(chan chunk, d.cfg.Concurrency) - d.maxPart = maxPart d.pos = d.params.Range.Start d.maxPos = d.params.Range.Start + d.params.Range.Length d.concurrency = d.cfg.Concurrency - _ = d.sendChunkTask(true) - var rc io.ReadCloser = NewMultiReadCloser(d.bufs[0], d.interrupt, d.finishBuf) + var err error + d.hc, err = hcache.NewHybridCache(uint64(d.cfg.PartSize), uint64(d.params.Range.Length)) + if err == nil { + d.bufMap = make(map[int]*buffer.PipeBuffer, d.cfg.Concurrency) + err = d.sendChunkTask(true) + } + if err != nil { + d.cancel(err) + d.cfg.ConcurrencyLimit.Release() + return nil, d.interrupt() + } - // Return error - return rc, d.err + d.mu.Lock() + defer d.mu.Unlock() + return &multiReadCloser{d: d, curBuf: d.popBuf(0), maxPos: maxPart}, nil } -func (d *downloader) sendChunkTask(newConcurrency bool) error { - d.m.Lock() - defer d.m.Unlock() - isNewBuf := d.concurrency > 0 +func (d *downloader) sendChunkTask(newConcurrency bool) (err error) { + d.mu.Lock() + defer d.mu.Unlock() + if d.pos >= d.maxPos { + return nil + } if newConcurrency { if d.concurrency <= 0 { return nil } if d.nextChunk > 0 { // 第一个不检查,因为已经检查过了 - if err := d.concurrencyCheck(); err != nil { + if err := d.cfg.ConcurrencyLimit.Acquire(); err != nil { return err } + defer func() { + if err != nil { + d.cfg.ConcurrencyLimit.Release() + } + }() } - d.concurrency-- - go d.downloadPart() } - var buf *Buf - if isNewBuf { - buf = NewBuf(d.ctx, d.cfg.PartSize) - d.bufs = append(d.bufs, buf) - } else { - buf = d.getBuf(d.nextChunk) - } - - if d.pos < d.maxPos { - finalSize := int64(d.cfg.PartSize) - switch d.nextChunk { - case 0: - // 最小分片在前面有助视频播放? - firstSize := d.params.Range.Length % finalSize - if firstSize > 0 { - minSize := finalSize / 2 - if firstSize < minSize { // 最小分片太小就调整到一半 - finalSize = minSize - } else { - finalSize = firstSize - } - } - case 1: - firstSize := d.params.Range.Length % finalSize - minSize := finalSize / 2 - if firstSize > 0 && firstSize < minSize { - finalSize += firstSize - minSize - } - } - err := buf.Reset(int(finalSize)) + br := d.bufMap[d.nextChunk] + if br == nil { + var b buffer.Block + b, err = d.hc.NextBlock() if err != nil { return err } - ch := chunk{ - start: d.pos, - size: finalSize, - id: d.nextChunk, - buf: buf, + br = buffer.NewPipeBuffer(d.ctx, b) + d.bufMap[d.nextChunk] = br + } - newConcurrency: newConcurrency, + finalSize := int64(d.cfg.PartSize) + switch d.nextChunk { + case 0: + // 最小分片在前面有助视频播放? + firstSize := d.params.Range.Length % finalSize + if firstSize > 0 { + minSize := finalSize / 2 + // 最小分片太小就调整到一半 + finalSize = max(firstSize, minSize) } - d.pos += finalSize - d.nextChunk++ - d.chunkChannel <- ch + case 1: + firstSize := d.params.Range.Length % finalSize + minSize := finalSize / 2 + if firstSize > 0 && firstSize < minSize { + finalSize += firstSize - minSize + } + } + err = br.Reset(int(finalSize)) + if err != nil { + return err // 分片算法错误或者下载中断 + } + if newConcurrency { + go d.downloadPart() + d.concurrency-- + } + ch := chunk{ + start: d.pos, + size: finalSize, + id: d.nextChunk, + buf: br, + + newConcurrency: newConcurrency, + } + d.pos += finalSize + d.nextChunk++ + select { + case <-d.ctx.Done(): + return context.Cause(d.ctx) + case d.chunkCh <- ch: return nil } - return nil } // when the final reader Close, we interrupt func (d *downloader) interrupt() error { - d.m.Lock() - defer d.m.Unlock() - err := d.err - if err == nil && d.written != d.params.Range.Length { - log.Debugf("Downloader interrupt before finish") - err := fmt.Errorf("interrupted") - d.err = err - } - close(d.chunkChannel) - if d.bufs != nil { - d.cancel(err) - for _, buf := range d.bufs { - buf.Close() - } - d.bufs = nil + err := context.Cause(d.ctx) + if err == nil { + if atomic.LoadInt64(&d.written) != d.params.Range.Length { + err = fmt.Errorf("interrupted") + } + } else if errors.Is(err, context.Canceled) { + err = nil + } + d.cancel(err) + d.mu.Lock() + defer d.mu.Unlock() + if d.bufMap != nil { + for _, buf := range d.bufMap { + _ = buf.Close() + } + d.bufMap = nil + } + if d.hc != nil { + _ = d.hc.Close() + d.hc = nil + } + if d.maxPos != 0 { + d.maxPos = 0 + close(d.chunkCh) if d.concurrency > 0 { d.concurrency = -d.concurrency } @@ -300,49 +323,49 @@ func (d *downloader) interrupt() error { } return err } -func (d *downloader) getBuf(id int) (b *Buf) { - return d.bufs[id%len(d.bufs)] +func (d *downloader) popBuf(id int) *buffer.PipeBuffer { + br := d.bufMap[id] + if br != nil { + delete(d.bufMap, id) + d.bufMap[-1] = br // -1 保存最后一次取出的buf用来关闭 + } + return br } -func (d *downloader) finishBuf(id int) (isLast bool, nextBuf *Buf) { - id++ - if id >= d.maxPart { - return true, nil + +func (d *downloader) finishBuf(nextId int, prev *buffer.PipeBuffer) (next *buffer.PipeBuffer) { + atomic.StoreInt64(&d.readingID, int64(nextId)) + + d.mu.Lock() + shouldSendTask := d.bufMap[d.nextChunk] == nil + if shouldSendTask { + d.bufMap[d.nextChunk] = prev } + d.mu.Unlock() - _ = d.sendChunkTask(false) + if shouldSendTask { + _ = d.sendChunkTask(false) + } else { + _ = prev.Close() + } - d.readingID = id - return false, d.getBuf(id) + d.mu.Lock() + defer d.mu.Unlock() + return d.popBuf(nextId) } // downloadPart is an individual goroutine worker reading from the ch channel // and performing Http request on the data with a given byte range. func (d *downloader) downloadPart() { - defer d.concurrencyFinish() + defer d.cfg.ConcurrencyLimit.Release() for { select { case <-d.ctx.Done(): return - case c, ok := <-d.chunkChannel: + case c, ok := <-d.chunkCh: if !ok { return } - if d.getErr() != nil { - // Drain the channel if there is an error, to prevent deadlocking - // of download producer. - return - } - if err := d.downloadChunk(&c); err != nil { - if err == errCancelConcurrency { - return - } - if err == context.Canceled { - if e := context.Cause(d.ctx); e != nil { - err = e - } - } - d.setErr(err) - d.cancel(err) + if !d.downloadChunk(&c) { return } } @@ -350,26 +373,20 @@ func (d *downloader) downloadPart() { } // downloadChunk downloads the chunk -func (d *downloader) downloadChunk(ch *chunk) error { +func (d *downloader) downloadChunk(ch *chunk) bool { log.Debugf("start chunk_%d, %+v", ch.id, ch) params := d.getParamsFromChunk(ch) - var n int64 var err error for retry := 0; retry <= d.cfg.PartBodyMaxRetries; retry++ { - if d.getErr() != nil { - return nil - } + var n int64 n, err = d.tryDownloadChunk(params, ch) if err == nil { d.incrWritten(n) log.Debugf("chunk_%d downloaded", ch.id) - break + return true } - if d.getErr() != nil { - return nil - } - if utils.IsCanceled(d.ctx) { - return d.ctx.Err() + if d.ctx.Err() != nil { + return false } // Check if the returned error is an errNeedRetry. // If this occurs we unwrap the err to set the underlying error @@ -389,17 +406,31 @@ func (d *downloader) downloadChunk(ch *chunk) error { ch.id, params.URL, retry, err) } else if err == errInfiniteRetry { retry-- - continue + } else if err == errCancelConcurrency { + return false // 取消一个的并发 } else { break } } + if err != nil { + d.cancel(err) // 取消所有的并发 + } + return false +} - return err +func (d *downloader) delay(ti time.Duration) bool { + t := time.NewTimer(ti) + select { + case <-d.ctx.Done(): + t.Stop() + return false + case <-t.C: + return true + } } -var errCancelConcurrency = errors.New("cancel concurrency") -var errInfiniteRetry = errors.New("infinite retry") +var errCancelConcurrency = errors.New("") +var errInfiniteRetry = errors.New("") func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) { resp, err := d.cfg.HttpClient(d.ctx, params) @@ -420,15 +451,17 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int case http.StatusServiceUnavailable: case http.StatusGatewayTimeout: } - <-time.After(time.Millisecond * 200) - return 0, &errNeedRetry{err: err} + if !d.delay(time.Millisecond * time.Duration(rand.Uint32N(300)+200)) { + return 0, errCancelConcurrency + } + return 0, &errNeedRetry{err} } // 来到这 说明第1个分片下载 连接成功了 // 后续分片下载出错都当超载处理 log.Debugf("err chunk_%d, try downloading:%v", ch.id, err) - d.m.Lock() + d.mu.Lock() isCancelConcurrency := ch.newConcurrency if d.concurrency > 0 { // 取消剩余的并发任务 // 用于计算实际的并发数 @@ -437,18 +470,25 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int } if isCancelConcurrency { d.concurrency-- - d.chunkChannel <- *ch - d.m.Unlock() - return 0, errCancelConcurrency + d.mu.Unlock() + select { + case <-d.ctx.Done(): + return 0, errCancelConcurrency + case d.chunkCh <- *ch: + return 0, errCancelConcurrency + } } - d.m.Unlock() - if ch.id != d.readingID { //正在被读取的优先重试 - d.m2.Lock() - defer d.m2.Unlock() - <-time.After(time.Millisecond * 200) + d.mu.Unlock() + if int64(ch.id) != atomic.LoadInt64(&d.readingID) { //正在被读取的优先重试 + d.delayMu.Lock() + defer d.delayMu.Unlock() + if !d.delay(time.Millisecond * time.Duration(rand.Uint32N(300)+200)) { + return 0, errCancelConcurrency + } } return 0, errInfiniteRetry } + defer resp.Body.Close() //only check file size on the first task if ch.id == 0 { @@ -461,11 +501,11 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int n, err := utils.CopyWithBuffer(ch.buf, resp.Body) if err != nil { - return n, &errNeedRetry{err: err} + return n, &errNeedRetry{err} } if n != ch.size { err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n) - return n, &errNeedRetry{err: err} + return n, &errNeedRetry{err} } return n, nil @@ -497,7 +537,8 @@ func (d *downloader) checkTotalBytes(resp *http.Response) error { totalStr := stdpath.Base(contentRange) if totalStr != "*" { - if total, err := strconv.ParseInt(totalStr, 10, 64); err != nil { + var total int64 + if total, err = strconv.ParseInt(totalStr, 10, 64); err != nil { err = fmt.Errorf("failed extracting file size: %s", totalStr) } else { totalBytes = total @@ -510,35 +551,12 @@ func (d *downloader) checkTotalBytes(resp *http.Response) error { if totalBytes != d.params.Size && err == nil { err = fmt.Errorf("expect file size=%d unmatch remote report size=%d, need refresh cache", d.params.Size, totalBytes) } - if err != nil { - d.setErr(err) - d.cancel(err) - } return err } func (d *downloader) incrWritten(n int64) { - d.m.Lock() - defer d.m.Unlock() - - d.written += n -} - -// getErr is a thread-safe getter for the error object -func (d *downloader) getErr() error { - d.m.Lock() - defer d.m.Unlock() - - return d.err -} - -// setErr is a thread-safe setter for the error object -func (d *downloader) setErr(e error) { - d.m.Lock() - defer d.m.Unlock() - - d.err = e + atomic.AddInt64(&d.written, n) } // Chunk represents a single chunk of data to write by the worker routine. @@ -548,7 +566,7 @@ func (d *downloader) setErr(e error) { type chunk struct { start int64 size int64 - buf *Buf + buf *buffer.PipeBuffer id int newConcurrency bool @@ -587,196 +605,39 @@ type HttpRequestParams struct { Size int64 } type errNeedRetry struct { - err error -} - -func (e *errNeedRetry) Error() string { - return e.err.Error() + error } func (e *errNeedRetry) Unwrap() error { - return e.err -} - -type MultiReadCloser struct { - cfg *cfg - closer closerFunc - finish finishBufFUnc -} - -type cfg struct { - rPos int //current reader position, start from 0 - curBuf *Buf + return e.error } -type closerFunc func() error -type finishBufFUnc func(id int) (isLast bool, buf *Buf) - -// NewMultiReadCloser to save memory, we re-use limited Buf, and feed data to Read() -func NewMultiReadCloser(buf *Buf, c closerFunc, fb finishBufFUnc) *MultiReadCloser { - return &MultiReadCloser{closer: c, finish: fb, cfg: &cfg{curBuf: buf}} +type multiReadCloser struct { + pos int //current reader position, start from 0 + maxPos int + curBuf *buffer.PipeBuffer + d *downloader } -func (mr MultiReadCloser) Read(p []byte) (n int, err error) { - if mr.cfg.curBuf == nil { +func (mr *multiReadCloser) Read(p []byte) (n int, err error) { + if mr.curBuf == nil { return 0, io.EOF } - n, err = mr.cfg.curBuf.Read(p) - //log.Debugf("read_%d read current buffer, n=%d ,err=%+v", mr.cfg.rPos, n, err) + n, err = mr.curBuf.Read(p) + // log.Debugf("read_%d read current buffer, n=%d ,err=%+v", mr.rPos, n, err) if err == io.EOF { - log.Debugf("read_%d finished current buffer", mr.cfg.rPos) + log.Debugf("read_%d finished current buffer", mr.pos) - isLast, next := mr.finish(mr.cfg.rPos) - if isLast { + mr.pos++ + if mr.pos >= mr.maxPos { return n, io.EOF } - mr.cfg.curBuf = next - mr.cfg.rPos++ + mr.curBuf = mr.d.finishBuf(mr.pos, mr.curBuf) return n, nil } - if err == context.Canceled { - if e := context.Cause(mr.cfg.curBuf.ctx); e != nil { - err = e - } - } return n, err } -func (mr MultiReadCloser) Close() error { - return mr.closer() -} - -type Buf struct { - size int //expected size - ctx context.Context - offR int - offW int - rw sync.Mutex - buf []byte - mmap bool - - readSignal chan struct{} - readPending bool -} - -// NewBuf is a buffer that can have 1 read & 1 write at the same time. -// when read is faster write, immediately feed data to read after written -func NewBuf(ctx context.Context, maxSize int) *Buf { - br := &Buf{ - ctx: ctx, - size: maxSize, - readSignal: make(chan struct{}, 1), - } - if conf.MmapThreshold > 0 && maxSize >= conf.MmapThreshold { - m, err := mmap.Alloc(maxSize) - if err == nil { - br.buf = m - br.mmap = true - return br - } - } - br.buf = make([]byte, maxSize) - return br -} - -func (br *Buf) Reset(size int) error { - br.rw.Lock() - defer br.rw.Unlock() - if br.buf == nil { - return io.ErrClosedPipe - } - if size > cap(br.buf) { - return fmt.Errorf("reset size %d exceeds max size %d", size, cap(br.buf)) - } - br.size = size - br.offR = 0 - br.offW = 0 - return nil -} -func (br *Buf) Read(p []byte) (int, error) { - if err := br.ctx.Err(); err != nil { - return 0, err - } - if len(p) == 0 { - return 0, nil - } - if br.offR >= br.size { - return 0, io.EOF - } - for { - br.rw.Lock() - if br.buf == nil { - br.rw.Unlock() - return 0, io.ErrClosedPipe - } - - if br.offW < br.offR { - br.rw.Unlock() - return 0, io.ErrUnexpectedEOF - } - if br.offW == br.offR { - br.readPending = true - br.rw.Unlock() - select { - case <-br.ctx.Done(): - return 0, br.ctx.Err() - case _, ok := <-br.readSignal: - if !ok { - return 0, io.ErrClosedPipe - } - continue - } - } - - n := copy(p, br.buf[br.offR:br.offW]) - br.offR += n - br.rw.Unlock() - if n < len(p) && br.offR >= br.size { - return n, io.EOF - } - return n, nil - } -} - -func (br *Buf) Write(p []byte) (int, error) { - if err := br.ctx.Err(); err != nil { - return 0, err - } - if len(p) == 0 { - return 0, nil - } - br.rw.Lock() - defer br.rw.Unlock() - if br.buf == nil { - return 0, io.ErrClosedPipe - } - if br.offW >= br.size { - return 0, io.ErrShortWrite - } - n := copy(br.buf[br.offW:], p[:min(br.size-br.offW, len(p))]) - br.offW += n - if br.readPending { - br.readPending = false - select { - case br.readSignal <- struct{}{}: - default: - } - } - if n < len(p) { - return n, io.ErrShortWrite - } - return n, nil -} - -func (br *Buf) Close() error { - br.rw.Lock() - defer br.rw.Unlock() - var err error - if br.mmap { - err = mmap.Free(br.buf) - br.mmap = false - } - br.buf = nil - close(br.readSignal) - return err +func (mr *multiReadCloser) Close() error { + return mr.d.interrupt() } diff --git a/internal/net/request_test.go b/internal/net/request_test.go index da16a3165..0fdc56eb3 100644 --- a/internal/net/request_test.go +++ b/internal/net/request_test.go @@ -11,13 +11,12 @@ import ( "net/http" "sync" "testing" + "time" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/sirupsen/logrus" ) -var buf22MB = make([]byte, 1024*1024*22) - func containsString(slice []string, val string) bool { for _, item := range slice { if item == val { @@ -27,18 +26,6 @@ func containsString(slice []string, val string) bool { return false } -func dummyHttpRequest(data []byte, p http_range.Range) io.ReadCloser { - - end := p.Start + p.Length - 1 - - if end >= int64(len(data)) { - end = int64(len(data)) - } - - bodyBytes := data[p.Start:end] - return io.NopCloser(bytes.NewReader(bodyBytes)) -} - func TestDownloadOrder(t *testing.T) { buff := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} downloader, invocations, ranges := newDownloadRangeClient(buff) @@ -67,8 +54,8 @@ func TestDownloadOrder(t *testing.T) { if err != nil { t.Fatalf("expect no error, got %v", err) } - if exp, a := int(length), len(resultBuf); exp != a { - t.Errorf("expect buffer length=%d, got %d", exp, a) + if exp, a := buff[start:start+length2], resultBuf; !bytes.Equal(exp, a) { + t.Errorf("expect buffer %v, got %v", exp, a) } chunkSize := int(length+int64(partSize)-1) / partSize if e, a := chunkSize, *invocations; e != a { @@ -84,7 +71,100 @@ func TestDownloadOrder(t *testing.T) { if e, a := expectRngs, *ranges; len(e) != len(a) { t.Errorf("expect %v ranges, got %v", e, a) } + if err := readCloser.Close(); err != nil { + t.Errorf("expect no error on close, got %v", err) + } } + +func TestDownloadInterrupt(t *testing.T) { + buff := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + buff = append(buff, buff...) + downloader, _, _ := newDownloadRangeClient(buff) + con, partSize := 6, 3 + d := NewDownloader(func(d *Downloader) { + d.Concurrency = con + d.PartSize = partSize + d.HttpClient = downloader.HttpRequest + d.ConcurrencyLimit = &ConcurrencyLimit{ + Limit: 5, + } + }) + + var start, length int64 = 0, int64(len(buff)) + req := &HttpRequestParams{ + Range: http_range.Range{Start: start, Length: length}, + Size: int64(len(buff)), + } + ctx, cancel := context.WithCancel(context.Background()) + readCloser, err := d.Download(ctx, req) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + _, err = io.CopyN(io.Discard, readCloser, 8) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + cancel() + if err := readCloser.Close(); err != nil { + t.Errorf("expect no error on close, got %v", err) + } +} + +func TestHighConcurrency(t *testing.T) { + buff := make([]byte, 8<<10) + for i := range len(buff) { + buff[i] = byte(i % 256) + } + downloader, invocations, _ := newDownloadRangeClient(buff) + con, partSize := 64, 100 + concurrencyLimit := uint32(32) + d := NewDownloader(func(d *Downloader) { + d.Concurrency = con + d.PartSize = partSize + d.HttpClient = downloader.HttpRequest + d.ConcurrencyLimit = &ConcurrencyLimit{ + Limit: concurrencyLimit, + } + }) + + var start, length int64 = 2, 7 << 10 + length2 := length + if length2 == -1 { + length2 = int64(len(buff)) - start + } + req := &HttpRequestParams{ + Range: http_range.Range{Start: start, Length: length}, + Size: int64(len(buff)), + } + readCloser, err := d.Download(context.Background(), req) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + resultBuf, err := io.ReadAll(readCloser) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if !bytes.Equal(buff[start:start+length2], resultBuf) { + t.Error("expect buffer content matches, but got mismatch") + } + chunkSize := int(length+int64(partSize)-1) / partSize + if e, a := chunkSize, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + if err := readCloser.Close(); err != nil { + t.Errorf("expect no error on close, got %v", err) + } + for range 100 { + time.Sleep(10 * time.Millisecond) + if d.ConcurrencyLimit.Limit == concurrencyLimit { + return + } + } + t.Errorf("expect concurrency limit to be %v, got %v", concurrencyLimit, d.ConcurrencyLimit.Limit) +} + func init() { Formatter := new(logrus.TextFormatter) Formatter.TimestampFormat = "2006-01-02T15:04:05.999999999" @@ -136,6 +216,9 @@ func TestDownloadSingle(t *testing.T) { if e, a := expectRngs, *ranges; len(e) != len(a) { t.Errorf("expect %v ranges, got %v", e, a) } + if err := readCloser.Close(); err != nil { + t.Errorf("expect no error on close, got %v", err) + } } type downloadCaptureClient struct { diff --git a/internal/net/serve.go b/internal/net/serve.go index 6a20460b1..ee288b86a 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "mime/multipart" + stdnet "net" // 标准库net包,用于Dialer "net/http" "strconv" "strings" @@ -286,12 +287,20 @@ func NewHttpClient() *http.Client { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}, + // 快速连接超时:10秒建立连接,失败快速重试 + DialContext: (&stdnet.Dialer{ + Timeout: 10 * time.Second, // TCP握手超时 + KeepAlive: 30 * time.Second, // TCP keep-alive + }).DialContext, + // 响应头超时:15秒等待服务器响应头(平衡API调用与下载检测) + ResponseHeaderTimeout: 15 * time.Second, + // 允许长时间读取数据(无 IdleConnTimeout 限制) } SetProxyIfConfigured(transport) return &http.Client{ - Timeout: time.Hour * 48, + Timeout: time.Hour * 48, // 总超时保持48小时(允许大文件慢速下载) Transport: transport, } } diff --git a/internal/offline_download/115/client.go b/internal/offline_download/115/client.go index 9e9f702d5..019de65af 100644 --- a/internal/offline_download/115/client.go +++ b/internal/offline_download/115/client.go @@ -127,7 +127,7 @@ func (p *Cloud115) Status(task *tool.DownloadTask) (*tool.Status, error) { s.Completed = t.IsDone() s.TotalBytes = t.Size if t.IsFailed() { - s.Err = fmt.Errorf(t.GetStatus()) + s.Err = fmt.Errorf("%s", t.GetStatus()) } return s, nil } diff --git a/internal/offline_download/115_open/client.go b/internal/offline_download/115_open/client.go index d12e02ec5..56669a674 100644 --- a/internal/offline_download/115_open/client.go +++ b/internal/offline_download/115_open/client.go @@ -2,8 +2,16 @@ package _115_open import ( "context" + "encoding/base32" + "encoding/hex" + "errors" "fmt" + "net/url" + "strconv" + "strings" + "time" + sdk "github.com/OpenListTeam/115-sdk-go" _115_open "github.com/OpenListTeam/OpenList/v4/drivers/115_open" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/setting" @@ -12,11 +20,34 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/offline_download/tool" "github.com/OpenListTeam/OpenList/v4/internal/op" + log "github.com/sirupsen/logrus" ) type Open115 struct { } +type offlineTaskClient interface { + OfflineDownload(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) + OfflineList(ctx context.Context) (*sdk.OfflineTaskListResp, error) + DeleteOfflineTask(ctx context.Context, infoHash string, deleteFiles bool) error +} + +type offlineTaskDetailClient interface { + OfflineDownloadWithDetails(ctx context.Context, uris []string, dstDir model.Obj) ([]string, []sdk.AddOfflineTaskURIsResp, string, error) +} + +type offlineTaskLimiter interface { + WaitLimit(ctx context.Context) error +} + +func waitOfflineTaskLimit(ctx context.Context, client offlineTaskClient) error { + limiter, ok := client.(offlineTaskLimiter) + if !ok { + return nil + } + return limiter.WaitLimit(ctx) +} + func (o *Open115) Name() string { return "115 Open" } @@ -68,15 +99,473 @@ func (o *Open115) AddURL(args *tool.AddUrlArgs) (string, error) { if err != nil { return "", err } + log.Infof("[115_open] AddURL start: temp_dir=%q actual_path=%q parent_id=%q parent_name=%q url=%q", args.TempDir, actualPath, parentDir.GetID(), parentDir.GetName(), args.Url) + logOfflineURLDetails("[115_open] AddURL input", args.Url) - hashs, err := driver115Open.OfflineDownload(ctx, []string{args.Url}, parentDir) - if err != nil || len(hashs) < 1 { - return "", fmt.Errorf("failed to add offline download task: %w", err) + hashs, err := addOfflineDownloadTask(ctx, driver115Open, args.Url, parentDir) + if err != nil { + return "", err + } + + if len(hashs) < 1 { + return "", fmt.Errorf("failed to add offline download task: no task hash returned") } return hashs[0], nil } +func addOfflineDownloadTask(ctx context.Context, client offlineTaskClient, url string, parentDir model.Obj) ([]string, error) { + parentID, parentName := "", "" + if parentDir != nil { + parentID = parentDir.GetID() + parentName = parentDir.GetName() + } + log.Infof("[115_open] addOfflineDownloadTask: parent_id=%q parent_name=%q url=%q", parentID, parentName, url) + logOfflineURLDetails("[115_open] addOfflineDownloadTask target", url) + if err := preCleanDuplicateOfflineTasks(ctx, client, url); err != nil { + return nil, err + } + hashs, addItems, rawResp, err := offlineDownloadWithDetails(ctx, client, url, parentDir) + log.Infof("[115_open] addOfflineDownloadTask first attempt result: hashes=%v err=%v add_items=%d", hashs, err, len(addItems)) + if err == nil { + return hashs, nil + } + if !isDuplicateOfflineTaskError(err) { + return nil, fmt.Errorf("failed to add offline download task: %w", err) + } + log.Infof("[115_open] duplicate offline task detected, trying cleanup before retry") + if rawResp != "" { + log.Infof("[115_open] duplicate add response: %s", rawResp) + } + for _, item := range addItems { + log.Infof("[115_open] duplicate add item: state=%v code=%d info_hash=%q url=%q", item.State, item.Code, item.InfoHash, item.URL) + logOfflineURLDetails("[115_open] duplicate add item url", item.URL) + if item.InfoHash == "" { + log.Infof("[115_open] skipping add-response duplicate item: empty info_hash") + continue + } + if item.URL != "" && !offlineTaskURLMatches(item.URL, url) { + log.Infof("[115_open] skipping add-response duplicate item: url mismatch") + continue + } + log.Infof("[115_open] deleting duplicate task directly from add response: info_hash=%s url=%s", item.InfoHash, item.URL) + if err := waitOfflineTaskLimit(ctx, client); err != nil { + return nil, err + } + if deleteErr := client.DeleteOfflineTask(ctx, item.InfoHash, false); deleteErr != nil { + log.Errorf("[115_open] delete duplicate task from add response failed: info_hash=%s err=%v", item.InfoHash, deleteErr) + return nil, fmt.Errorf("failed to delete duplicate offline download task from add response: %w", deleteErr) + } + log.Infof("[115_open] delete duplicate task from add response success: info_hash=%s", item.InfoHash) + waitForOfflineTaskRemoval(ctx, client, item.InfoHash) + hashs, retryItems, retryRawResp, retryErr := offlineDownloadWithDetails(ctx, client, url, parentDir) + log.Infof("[115_open] retry add after add-response delete: hashes=%v err=%v add_items=%d", hashs, retryErr, len(retryItems)) + if retryRawResp != "" { + log.Infof("[115_open] retry add raw response after add-response delete: %s", retryRawResp) + } + err = retryErr + if err != nil { + return nil, fmt.Errorf("failed to add offline download task after removing duplicate: %w", err) + } + return hashs, nil + } + if err := waitOfflineTaskLimit(ctx, client); err != nil { + return nil, err + } + taskList, listErr := client.OfflineList(ctx) + if listErr != nil || taskList == nil { + return nil, fmt.Errorf("failed to add offline download task: %w", err) + } + log.Infof("[115_open] offline list returned %d tasks across %d pages", len(taskList.Tasks), taskList.PageCount) + for _, task := range taskList.Tasks { + matched, reason := offlineTaskMatchReason(task, url) + log.Infof("[115_open] duplicate candidate: info_hash=%s status=%d size=%d name=%q url=%q matched=%v reason=%s", task.InfoHash, task.Status, task.Size, task.Name, task.URL, matched, reason) + logOfflineURLDetails("[115_open] duplicate candidate url", task.URL) + if !matched { + continue + } + log.Infof("[115_open] matched duplicate offline task: info_hash=%s, name=%s", task.InfoHash, task.Name) + log.Infof("[115_open] deleting matched duplicate offline task: info_hash=%s status=%d size=%d", task.InfoHash, task.Status, task.Size) + if err := waitOfflineTaskLimit(ctx, client); err != nil { + return nil, err + } + if deleteErr := client.DeleteOfflineTask(ctx, task.InfoHash, false); deleteErr != nil { + log.Errorf("[115_open] delete matched duplicate offline task failed: info_hash=%s err=%v", task.InfoHash, deleteErr) + return nil, fmt.Errorf("failed to delete duplicate offline download task: %w", deleteErr) + } + log.Infof("[115_open] delete matched duplicate offline task success: info_hash=%s", task.InfoHash) + waitForOfflineTaskRemoval(ctx, client, task.InfoHash) + hashs, retryItems, retryRawResp, retryErr := offlineDownloadWithDetails(ctx, client, url, parentDir) + log.Infof("[115_open] retry add after matched delete: hashes=%v err=%v add_items=%d", hashs, retryErr, len(retryItems)) + if retryRawResp != "" { + log.Infof("[115_open] retry add raw response after matched delete: %s", retryRawResp) + } + err = retryErr + if err != nil { + return nil, fmt.Errorf("failed to add offline download task after removing duplicate: %w", err) + } + return hashs, nil + } + log.Warnf("[115_open] duplicate offline task detected but no matching task found in offline list") + return nil, fmt.Errorf("failed to add offline download task: %w", err) +} + +func preCleanDuplicateOfflineTasks(ctx context.Context, client offlineTaskClient, url string) error { + if err := waitOfflineTaskLimit(ctx, client); err != nil { + return err + } + taskList, listErr := client.OfflineList(ctx) + if listErr != nil || taskList == nil { + log.Warnf("[115_open] pre-add offline list failed: err=%v", listErr) + return nil + } + log.Infof("[115_open] pre-add offline list returned %d tasks across %d pages", len(taskList.Tasks), taskList.PageCount) + deleted := 0 + for _, task := range taskList.Tasks { + matched, reason := offlineTaskMatchReason(task, url) + log.Infof("[115_open] pre-add duplicate candidate: info_hash=%s status=%d size=%d name=%q url=%q matched=%v reason=%s", task.InfoHash, task.Status, task.Size, task.Name, task.URL, matched, reason) + logOfflineURLDetails("[115_open] pre-add duplicate candidate url", task.URL) + if !matched { + continue + } + log.Infof("[115_open] pre-add deleting matched duplicate offline task: info_hash=%s status=%d size=%d", task.InfoHash, task.Status, task.Size) + if err := waitOfflineTaskLimit(ctx, client); err != nil { + return err + } + if deleteErr := client.DeleteOfflineTask(ctx, task.InfoHash, false); deleteErr != nil { + log.Errorf("[115_open] pre-add delete matched duplicate offline task failed: info_hash=%s err=%v", task.InfoHash, deleteErr) + return fmt.Errorf("failed to delete duplicate offline download task: %w", deleteErr) + } + deleted++ + log.Infof("[115_open] pre-add delete matched duplicate offline task success: info_hash=%s", task.InfoHash) + waitForOfflineTaskRemoval(ctx, client, task.InfoHash) + } + if deleted == 0 { + log.Infof("[115_open] pre-add duplicate scan found no matches") + } + return nil +} + +func offlineDownloadWithDetails(ctx context.Context, client offlineTaskClient, url string, parentDir model.Obj) ([]string, []sdk.AddOfflineTaskURIsResp, string, error) { + if err := waitOfflineTaskLimit(ctx, client); err != nil { + return nil, nil, "", err + } + if detailClient, ok := client.(offlineTaskDetailClient); ok { + return detailClient.OfflineDownloadWithDetails(ctx, []string{url}, parentDir) + } + hashs, err := client.OfflineDownload(ctx, []string{url}, parentDir) + return hashs, nil, "", err +} + +func isDuplicateOfflineTaskError(err error) bool { + if err == nil { + return false + } + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "10008") || + strings.Contains(errStr, "重复") || + strings.Contains(errStr, "已存在") || + strings.Contains(errStr, "duplicate") +} + +func offlineTaskURLMatches(taskURL string, rawURL string) bool { + taskVariants := normalizedOfflineTaskURLVariants(taskURL) + rawVariants := normalizedOfflineTaskURLVariants(rawURL) + for candidate := range taskVariants { + if _, ok := rawVariants[candidate]; ok { + return true + } + } + return false +} + +func offlineTaskMatches(task sdk.OfflineTask, rawURL string) bool { + matched, _ := offlineTaskMatchReason(task, rawURL) + return matched +} + +func offlineTaskMatchReason(task sdk.OfflineTask, rawURL string) (bool, string) { + if offlineTaskURLMatches(task.URL, rawURL) { + return true, "url variants matched" + } + if httpURLMatches(task.URL, rawURL) { + return true, "http url host+path matched" + } + rawMagnet := parseMagnetBTIH(rawURL) + if rawMagnet != "" { + taskHash := normalizeInfoHash(task.InfoHash) + if taskHash != "" && taskHash == rawMagnet { + return true, "task info_hash matched raw magnet" + } + taskURLHash := parseMagnetBTIH(task.URL) + if taskURLHash != "" && taskURLHash == rawMagnet { + return true, "task url magnet hash matched" + } + return false, fmt.Sprintf("task magnet hash mismatch: task_info_hash=%q task_url_hash=%q raw_hash=%q", taskHash, taskURLHash, rawMagnet) + } + taskED2K, rawED2K := parseED2KLink(task.URL), parseED2KLink(rawURL) + if taskED2K != nil && rawED2K != nil { + if taskED2K.Hash == rawED2K.Hash { + if taskED2K.Size == rawED2K.Size { + return true, "task url ed2k hash matched" + } + return true, "task url ed2k hash matched despite size mismatch" + } + return false, fmt.Sprintf("task url ed2k mismatch: task=%s raw=%s", taskED2K.String(), rawED2K.String()) + } + if rawED2K == nil { + return false, "raw url is not ed2k and url variants did not match" + } + if normalizeOfflineTaskURL(task.InfoHash) == rawED2K.Hash { + if task.Size == rawED2K.Size { + return true, "task info_hash matched raw ed2k" + } + return true, "task info_hash matched raw ed2k despite size mismatch" + } + taskName := normalizeOfflineTaskURL(task.Name) + if taskName == normalizeOfflineTaskURL(rawED2K.Name) && task.Size == rawED2K.Size { + return true, "task name and size matched raw ed2k" + } + return false, fmt.Sprintf("task name/hash/size mismatch: task_name=%q raw_name=%q task_info_hash=%q raw_hash=%q task_size=%d raw_size=%d", taskName, normalizeOfflineTaskURL(rawED2K.Name), normalizeOfflineTaskURL(task.InfoHash), rawED2K.Hash, task.Size, rawED2K.Size) +} + +func normalizedOfflineTaskURLVariants(raw string) map[string]struct{} { + variants := map[string]struct{}{} + queue := []string{raw} + for len(queue) > 0 { + current := normalizeOfflineTaskURL(queue[0]) + queue = queue[1:] + if current == "" { + continue + } + if _, ok := variants[current]; ok { + continue + } + variants[current] = struct{}{} + if decoded, err := url.QueryUnescape(current); err == nil && decoded != current { + queue = append(queue, decoded) + } + if decoded, err := url.PathUnescape(current); err == nil && decoded != current { + queue = append(queue, decoded) + } + } + return variants +} + +func httpURLMatches(taskURL, rawURL string) bool { + taskNormalized := normalizeHTTPURL(taskURL) + rawNormalized := normalizeHTTPURL(rawURL) + if taskNormalized == "" || rawNormalized == "" { + return false + } + return taskNormalized == rawNormalized +} + +func normalizeHTTPURL(raw string) string { + parsed, err := url.Parse(strings.TrimSpace(raw)) + if err != nil || parsed == nil { + return "" + } + scheme := strings.ToLower(parsed.Scheme) + if scheme != "http" && scheme != "https" { + return "" + } + host := strings.ToLower(parsed.Host) + if host == "" { + return "" + } + path := strings.ToLower(parsed.Path) + path = strings.TrimSuffix(path, "/") + if path == "" { + path = "/" + } + return fmt.Sprintf("%s://%s%s", scheme, host, path) +} + +func normalizeOfflineTaskURL(raw string) string { + normalized := strings.TrimSpace(raw) + if normalized == "" { + return "" + } + normalized = strings.TrimSuffix(normalized, "/") + return strings.ToLower(normalized) +} + +type ed2kLink struct { + Name string + Size int64 + Hash string +} + +func parseED2KLink(raw string) *ed2kLink { + normalized := strings.TrimSpace(raw) + if !strings.HasPrefix(strings.ToLower(normalized), "ed2k://|file|") { + return nil + } + parts := strings.Split(normalized, "|") + if len(parts) < 6 { + return nil + } + name, err := url.PathUnescape(parts[2]) + if err != nil { + name = parts[2] + } + size, err := strconv.ParseInt(parts[3], 10, 64) + if err != nil { + return nil + } + return &ed2kLink{ + Name: normalizeOfflineTaskURL(name), + Size: size, + Hash: normalizeOfflineTaskURL(parts[4]), + } +} + +func parseMagnetBTIH(raw string) string { + if raw == "" { + return "" + } + lower := strings.ToLower(raw) + idx := strings.Index(lower, "btih:") + if idx == -1 { + return "" + } + candidate := raw[idx+len("btih:"):] + if candidate == "" { + return "" + } + for i, ch := range candidate { + if ch == '&' || ch == '#' || ch == '/' { + candidate = candidate[:i] + break + } + } + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return "" + } + if decoded, err := url.QueryUnescape(candidate); err == nil { + candidate = decoded + } + return normalizeInfoHash(candidate) +} + +func normalizeInfoHash(raw string) string { + normalized := strings.TrimSpace(raw) + if normalized == "" { + return "" + } + normalized = strings.TrimPrefix(strings.ToLower(normalized), "urn:btih:") + if normalized == "" { + return "" + } + if len(normalized) == 40 && isHexString(normalized) { + return normalized + } + if len(normalized) == 32 && isBase32String(normalized) { + decoded, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(normalized)) + if err == nil && len(decoded) == 20 { + return hex.EncodeToString(decoded) + } + } + return normalized +} + +func isHexString(value string) bool { + for _, ch := range value { + if (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') { + continue + } + return false + } + return true +} + +func isBase32String(value string) bool { + for _, ch := range value { + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '2' && ch <= '7') { + continue + } + return false + } + return true +} + +func (e *ed2kLink) Equal(other *ed2kLink) bool { + if e == nil || other == nil { + return false + } + return e.Name == other.Name && e.Size == other.Size && e.Hash == other.Hash +} + +func (e *ed2kLink) String() string { + if e == nil { + return "" + } + return fmt.Sprintf("name=%q size=%d hash=%q", e.Name, e.Size, e.Hash) +} + +func logOfflineURLDetails(prefix string, raw string) { + if raw == "" { + log.Infof("%s details: raw is empty", prefix) + return + } + variants := normalizedOfflineTaskURLVariants(raw) + log.Infof("%s details: raw=%q normalized_variants=%v", prefix, raw, mapKeys(variants)) + if parsedMagnet := parseMagnetBTIH(raw); parsedMagnet != "" { + log.Infof("%s details: parsed_magnet_hash=%s", prefix, parsedMagnet) + } + if parsed := parseED2KLink(raw); parsed != nil { + log.Infof("%s details: parsed_ed2k=%s", prefix, parsed.String()) + } +} + +func mapKeys(values map[string]struct{}) []string { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + return keys +} + +func waitForOfflineTaskRemoval(ctx context.Context, client offlineTaskClient, infoHash string) { + const maxChecks = 3 + for attempt := 1; attempt <= maxChecks; attempt++ { + if err := waitOfflineTaskLimit(ctx, client); err != nil { + log.Warnf("[115_open] post-delete wait limit failed: info_hash=%s attempt=%d err=%v", infoHash, attempt, err) + return + } + taskList, err := client.OfflineList(ctx) + if err != nil { + log.Warnf("[115_open] post-delete check failed: info_hash=%s attempt=%d err=%v", infoHash, attempt, err) + return + } + stillExists := false + taskStatus := -999 + taskName := "" + for _, task := range taskList.Tasks { + if normalizeOfflineTaskURL(task.InfoHash) != normalizeOfflineTaskURL(infoHash) { + continue + } + stillExists = true + taskStatus = task.Status + taskName = task.Name + break + } + log.Infof("[115_open] post-delete check: info_hash=%s attempt=%d exists=%v status=%d name=%q task_count=%d", infoHash, attempt, stillExists, taskStatus, taskName, len(taskList.Tasks)) + if !stillExists { + return + } + if attempt < maxChecks { + select { + case <-ctx.Done(): + return + case <-time.After(time.Second): + } + } + } +} + func (o *Open115) Remove(task *tool.DownloadTask) error { storage, _, err := op.GetStorageAndActualPath(task.TempDir) if err != nil { @@ -124,13 +613,15 @@ func (o *Open115) Status(task *tool.DownloadTask) (*tool.Status, error) { s.Completed = t.IsDone() s.TotalBytes = t.Size if t.IsFailed() { - s.Err = fmt.Errorf(t.GetStatus()) + s.Err = errors.New(t.GetStatus()) } return s, nil } } - s.Err = fmt.Errorf("the task has been deleted") - return nil, nil + // 任务不在列表中,可能已完成或被删除 + s.Progress = 100 + s.Completed = true + return s, nil } var _ tool.Tool = (*Open115)(nil) diff --git a/internal/offline_download/115_open/client_test.go b/internal/offline_download/115_open/client_test.go new file mode 100644 index 000000000..2e8b94778 --- /dev/null +++ b/internal/offline_download/115_open/client_test.go @@ -0,0 +1,616 @@ +package _115_open + +import ( + "context" + "fmt" + "strings" + "testing" + + sdk "github.com/OpenListTeam/115-sdk-go" + "github.com/OpenListTeam/OpenList/v4/internal/model" +) + +type mockOfflineTaskClient struct { + offlineDownloadFunc func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) + offlineDownloadWithDetailsFunc func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, []sdk.AddOfflineTaskURIsResp, string, error) + offlineListFunc func(ctx context.Context) (*sdk.OfflineTaskListResp, error) + deleteOfflineFunc func(ctx context.Context, infoHash string, deleteFiles bool) error + waitLimitFunc func(ctx context.Context) error + waitLimitCalls int +} + +func (m *mockOfflineTaskClient) OfflineDownload(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + return m.offlineDownloadFunc(ctx, uris, dstDir) +} + +func (m *mockOfflineTaskClient) OfflineDownloadWithDetails(ctx context.Context, uris []string, dstDir model.Obj) ([]string, []sdk.AddOfflineTaskURIsResp, string, error) { + if m.offlineDownloadWithDetailsFunc == nil { + hashes, err := m.OfflineDownload(ctx, uris, dstDir) + return hashes, nil, "", err + } + return m.offlineDownloadWithDetailsFunc(ctx, uris, dstDir) +} + +func (m *mockOfflineTaskClient) OfflineList(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + return m.offlineListFunc(ctx) +} + +func (m *mockOfflineTaskClient) DeleteOfflineTask(ctx context.Context, infoHash string, deleteFiles bool) error { + return m.deleteOfflineFunc(ctx, infoHash, deleteFiles) +} + +func (m *mockOfflineTaskClient) WaitLimit(ctx context.Context) error { + m.waitLimitCalls++ + if m.waitLimitFunc != nil { + return m.waitLimitFunc(ctx) + } + return nil +} + +func TestIsDuplicateOfflineTaskError(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "code 10008", err: fmt.Errorf("code: 10008"), want: true}, + {name: "chinese duplicate", err: fmt.Errorf("任务重复"), want: true}, + {name: "already exists", err: fmt.Errorf("任务已存在"), want: true}, + {name: "english duplicate", err: fmt.Errorf("duplicate task"), want: true}, + {name: "other", err: fmt.Errorf("network timeout"), want: false}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isDuplicateOfflineTaskError(tc.err); got != tc.want { + t.Fatalf("want %v, got %v", tc.want, got) + } + }) + } +} + +func TestOfflineTaskURLMatches(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + taskURL string + rawURL string + want bool + }{ + { + name: "exact match", + taskURL: "ed2k://|file|test.avi|123|ABC|/", + rawURL: "ed2k://|file|test.avi|123|ABC|/", + want: true, + }, + { + name: "percent encoded file name", + taskURL: "ed2k://|file|[AVS]Azumi Mizushima [ネオパンストフェティッシュ Ver.19 水嶋あずみ](NOP-019)(2011.01.13).avi|1593601796|9E5CCC55541BD46EE8252BF100EFC46D|/", + rawURL: "ed2k://|file|[AVS]Azumi%20Mizushima%20[ネオパンストフェティッシュ%20Ver.19%20水嶋あずみ](NOP-019)(2011.01.13).avi|1593601796|9E5CCC55541BD46EE8252BF100EFC46D|/", + want: true, + }, + { + name: "case and trailing slash normalized", + taskURL: "ED2K://|FILE|TEST.AVI|123|ABC|", + rawURL: "ed2k://|file|test.avi|123|abc|/", + want: true, + }, + { + name: "different link", + taskURL: "ed2k://|file|a.avi|123|ABC|/", + rawURL: "ed2k://|file|b.avi|123|ABC|/", + want: false, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := offlineTaskURLMatches(tc.taskURL, tc.rawURL); got != tc.want { + t.Fatalf("want %v, got %v", tc.want, got) + } + }) + } +} + +func TestOfflineTaskMatches(t *testing.T) { + t.Parallel() + + t.Run("match ed2k by parsed fields when task url differs", func(t *testing.T) { + t.Parallel() + + task := sdk.OfflineTask{ + InfoHash: "server-task-hash", + Name: "[AVS]Azumi Mizushima [ネオパンストフェティッシュ Ver.19 水嶋あずみ](NOP-019)(2011.01.13).avi", + Size: 1593601796, + URL: "", + } + rawURL := "ed2k://|file|[AVS]Azumi%20Mizushima%20[ネオパンストフェティッシュ%20Ver.19%20水嶋あずみ](NOP-019)(2011.01.13).avi|1593601796|9E5CCC55541BD46EE8252BF100EFC46D|/" + + if !offlineTaskMatches(task, rawURL) { + t.Fatal("expected task to match by ed2k parsed fields") + } + }) + + t.Run("do not match different ed2k size", func(t *testing.T) { + t.Parallel() + + task := sdk.OfflineTask{ + Name: "[AVS]Azumi Mizushima [ネオパンストフェティッシュ Ver.19 水嶋あずみ](NOP-019)(2011.01.13).avi", + Size: 1, + } + rawURL := "ed2k://|file|[AVS]Azumi%20Mizushima%20[ネオパンストフェティッシュ%20Ver.19%20水嶋あずみ](NOP-019)(2011.01.13).avi|1593601796|9E5CCC55541BD46EE8252BF100EFC46D|/" + + if offlineTaskMatches(task, rawURL) { + t.Fatal("expected task not to match") + } + }) + + t.Run("magnet still matches by url", func(t *testing.T) { + t.Parallel() + + rawURL := "magnet:?xt=urn:btih:1234567890ABCDEF1234567890ABCDEF12345678&dn=test" + task := sdk.OfflineTask{ + InfoHash: "1234567890abcdef1234567890abcdef12345678", + URL: rawURL, + } + + if !offlineTaskMatches(task, rawURL) { + t.Fatal("expected magnet task to match by url") + } + }) + + t.Run("match magnet by btih despite noisy tracker", func(t *testing.T) { + t.Parallel() + + rawURL := "magnet:?xt=urn:btih:1234567890ABCDEF1234567890ABCDEF12345678&dn=test" + task := sdk.OfflineTask{ + InfoHash: "1234567890abcdef1234567890abcdef12345678", + URL: "magnet:?xt=urn:btih:1234567890ABCDEF1234567890ABCDEF12345678&dn=test&tr=%3C!DOCTYPE%20html%3E", + } + + if !offlineTaskMatches(task, rawURL) { + t.Fatal("expected magnet task to match by btih") + } + }) + + t.Run("match http by host and path", func(t *testing.T) { + t.Parallel() + + rawURL := "https://example.com/files/test.mp4" + task := sdk.OfflineTask{ + URL: "https://EXAMPLE.com/files/test.mp4?token=abc", + } + + if !offlineTaskMatches(task, rawURL) { + t.Fatal("expected http task to match by host and path") + } + }) +} + +func TestAddOfflineDownloadTask(t *testing.T) { + t.Parallel() + + const ( + testURL = "https://example.com/test.torrent" + firstHash = "hash-1" + staleHash = "hash-stale" + deleteError = "delete failed" + ) + + t.Run("success on first try", func(t *testing.T) { + t.Parallel() + + listCount := 0 + client := &mockOfflineTaskClient{ + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + return []string{firstHash}, nil + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + listCount++ + return &sdk.OfflineTaskListResp{Tasks: nil}, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + t.Fatal("DeleteOfflineTask should not be called") + return nil + }, + } + + hashes, err := addOfflineDownloadTask(context.Background(), client, testURL, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(hashes) != 1 || hashes[0] != firstHash { + t.Fatalf("unexpected hashes: %+v", hashes) + } + if listCount < 1 { + t.Fatalf("want pre-add offline list call, got %d", listCount) + } + }) + + t.Run("wait limit applied for add flow", func(t *testing.T) { + t.Parallel() + + client := &mockOfflineTaskClient{ + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + return []string{firstHash}, nil + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + return &sdk.OfflineTaskListResp{Tasks: nil}, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + return nil + }, + } + + _, err := addOfflineDownloadTask(context.Background(), client, testURL, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client.waitLimitCalls < 2 { + t.Fatalf("want wait limit calls >= 2, got %d", client.waitLimitCalls) + } + }) + + t.Run("delete duplicate and retry", func(t *testing.T) { + t.Parallel() + + callCount := 0 + deleteCount := 0 + listCount := 0 + client := &mockOfflineTaskClient{ + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + callCount++ + if callCount == 1 { + return nil, fmt.Errorf("code: 10008, message: 任务已存在") + } + return []string{firstHash}, nil + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + listCount++ + return &sdk.OfflineTaskListResp{ + Tasks: []sdk.OfflineTask{ + {InfoHash: staleHash, URL: testURL}, + }, + }, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + deleteCount++ + if infoHash != staleHash { + t.Fatalf("unexpected hash: %s", infoHash) + } + if deleteFiles { + t.Fatal("deleteFiles should be false") + } + return nil + }, + } + + hashes, err := addOfflineDownloadTask(context.Background(), client, testURL, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(hashes) != 1 || hashes[0] != firstHash { + t.Fatalf("unexpected hashes: %+v", hashes) + } + if callCount != 2 { + t.Fatalf("want 2 download attempts, got %d", callCount) + } + if deleteCount != 2 { + t.Fatalf("want 2 delete attempts (pre-add + duplicate), got %d", deleteCount) + } + if listCount < 2 { + t.Fatalf("want at least 2 offline list calls, got %d", listCount) + } + }) + + t.Run("delete duplicate magnet and retry", func(t *testing.T) { + t.Parallel() + + callCount := 0 + deleteCount := 0 + listCount := 0 + magnetURL := "magnet:?xt=urn:btih:1234567890ABCDEF1234567890ABCDEF12345678&dn=test" + client := &mockOfflineTaskClient{ + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + callCount++ + if callCount == 1 { + return nil, fmt.Errorf("code: 10008, message: 任务已存在") + } + return []string{firstHash}, nil + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + listCount++ + return &sdk.OfflineTaskListResp{ + Tasks: []sdk.OfflineTask{ + {InfoHash: staleHash, URL: magnetURL}, + }, + }, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + deleteCount++ + if infoHash != staleHash { + t.Fatalf("unexpected hash: %s", infoHash) + } + if deleteFiles { + t.Fatal("deleteFiles should be false") + } + return nil + }, + } + + hashes, err := addOfflineDownloadTask(context.Background(), client, magnetURL, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(hashes) != 1 || hashes[0] != firstHash { + t.Fatalf("unexpected hashes: %+v", hashes) + } + if callCount != 2 { + t.Fatalf("want 2 download attempts, got %d", callCount) + } + if deleteCount != 2 { + t.Fatalf("want 2 delete attempts (pre-add + duplicate), got %d", deleteCount) + } + if listCount < 2 { + t.Fatalf("want at least 2 offline list calls, got %d", listCount) + } + }) + + t.Run("delete duplicate and retry with decoded ed2k url", func(t *testing.T) { + t.Parallel() + + callCount := 0 + deleteCount := 0 + listCount := 0 + decodedURL := "ed2k://|file|[AVS]Azumi Mizushima [ネオパンストフェティッシュ Ver.19 水嶋あずみ](NOP-019)(2011.01.13).avi|1593601796|9E5CCC55541BD46EE8252BF100EFC46D|/" + encodedURL := "ed2k://|file|[AVS]Azumi%20Mizushima%20[ネオパンストフェティッシュ%20Ver.19%20水嶋あずみ](NOP-019)(2011.01.13).avi|1593601796|9E5CCC55541BD46EE8252BF100EFC46D|/" + + client := &mockOfflineTaskClient{ + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + callCount++ + if callCount == 1 { + return nil, fmt.Errorf("code: 10008, message: 任务已存在,请勿输入重复的链接地址") + } + return []string{firstHash}, nil + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + listCount++ + return &sdk.OfflineTaskListResp{ + Tasks: []sdk.OfflineTask{ + {InfoHash: staleHash, URL: decodedURL}, + }, + }, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + deleteCount++ + if infoHash != staleHash { + t.Fatalf("unexpected hash: %s", infoHash) + } + if deleteFiles { + t.Fatal("deleteFiles should be false") + } + return nil + }, + } + + hashes, err := addOfflineDownloadTask(context.Background(), client, encodedURL, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(hashes) != 1 || hashes[0] != firstHash { + t.Fatalf("unexpected hashes: %+v", hashes) + } + if callCount != 2 { + t.Fatalf("want 2 download attempts, got %d", callCount) + } + if deleteCount != 2 { + t.Fatalf("want 2 delete attempts (pre-add + duplicate), got %d", deleteCount) + } + if listCount < 2 { + t.Fatalf("want at least 2 offline list calls, got %d", listCount) + } + }) + + t.Run("delete duplicate and retry with empty task url but matching name and size", func(t *testing.T) { + t.Parallel() + + callCount := 0 + deleteCount := 0 + listCount := 0 + encodedURL := "ed2k://|file|[AVS]Azumi%20Mizushima%20[ネオパンストフェティッシュ%20Ver.19%20水嶋あずみ](NOP-019)(2011.01.13).avi|1593601796|9E5CCC55541BD46EE8252BF100EFC46D|/" + + client := &mockOfflineTaskClient{ + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + callCount++ + if callCount == 1 { + return nil, fmt.Errorf("code: 10008, message: 任务已存在,请勿输入重复的链接地址") + } + return []string{firstHash}, nil + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + listCount++ + return &sdk.OfflineTaskListResp{ + Tasks: []sdk.OfflineTask{ + { + InfoHash: staleHash, + Name: "[AVS]Azumi Mizushima [ネオパンストフェティッシュ Ver.19 水嶋あずみ](NOP-019)(2011.01.13).avi", + Size: 1593601796, + URL: "", + }, + }, + }, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + deleteCount++ + if infoHash != staleHash { + t.Fatalf("unexpected hash: %s", infoHash) + } + if deleteFiles { + t.Fatal("deleteFiles should be false") + } + return nil + }, + } + + hashes, err := addOfflineDownloadTask(context.Background(), client, encodedURL, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(hashes) != 1 || hashes[0] != firstHash { + t.Fatalf("unexpected hashes: %+v", hashes) + } + if callCount != 2 { + t.Fatalf("want 2 download attempts, got %d", callCount) + } + if deleteCount != 2 { + t.Fatalf("want 2 delete attempts (pre-add + duplicate), got %d", deleteCount) + } + if listCount < 2 { + t.Fatalf("want at least 2 offline list calls, got %d", listCount) + } + }) + + t.Run("delete duplicate directly from add response info hash", func(t *testing.T) { + t.Parallel() + + callCount := 0 + deleteCount := 0 + listCount := 0 + client := &mockOfflineTaskClient{ + offlineDownloadWithDetailsFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, []sdk.AddOfflineTaskURIsResp, string, error) { + callCount++ + if callCount == 1 { + return nil, []sdk.AddOfflineTaskURIsResp{ + {InfoHash: staleHash, URL: testURL}, + }, `{"state":false,"code":10008,"message":"任务已存在","data":[{"info_hash":"hash-stale","url":"` + testURL + `"}]}`, fmt.Errorf("code: 10008, message: 任务已存在") + } + return []string{firstHash}, nil, "", nil + }, + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + return nil, fmt.Errorf("unexpected fallback OfflineDownload call") + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + listCount++ + return &sdk.OfflineTaskListResp{Tasks: nil}, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + deleteCount++ + if infoHash != staleHash { + t.Fatalf("unexpected hash: %s", infoHash) + } + if deleteFiles { + t.Fatal("deleteFiles should be false") + } + return nil + }, + } + + hashes, err := addOfflineDownloadTask(context.Background(), client, testURL, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(hashes) != 1 || hashes[0] != firstHash { + t.Fatalf("unexpected hashes: %+v", hashes) + } + if callCount != 2 { + t.Fatalf("want 2 download attempts, got %d", callCount) + } + if deleteCount != 1 { + t.Fatalf("want 1 delete attempt, got %d", deleteCount) + } + if listCount < 1 { + t.Fatalf("want pre-add offline list call, got %d", listCount) + } + }) + + t.Run("duplicate delete failure", func(t *testing.T) { + t.Parallel() + + listCount := 0 + client := &mockOfflineTaskClient{ + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + return nil, fmt.Errorf("duplicate task") + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + listCount++ + return &sdk.OfflineTaskListResp{ + Tasks: []sdk.OfflineTask{ + {InfoHash: staleHash, URL: testURL}, + }, + }, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + return fmt.Errorf(deleteError) + }, + } + + _, err := addOfflineDownloadTask(context.Background(), client, testURL, nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), deleteError) { + t.Fatalf("unexpected error: %v", err) + } + if listCount < 1 { + t.Fatalf("want pre-add offline list call, got %d", listCount) + } + }) + + t.Run("non duplicate error is returned", func(t *testing.T) { + t.Parallel() + + listCount := 0 + client := &mockOfflineTaskClient{ + offlineDownloadFunc: func(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) { + return nil, fmt.Errorf("network timeout") + }, + offlineListFunc: func(ctx context.Context) (*sdk.OfflineTaskListResp, error) { + listCount++ + return &sdk.OfflineTaskListResp{Tasks: nil}, nil + }, + deleteOfflineFunc: func(ctx context.Context, infoHash string, deleteFiles bool) error { + t.Fatal("DeleteOfflineTask should not be called") + return nil + }, + } + + _, err := addOfflineDownloadTask(context.Background(), client, testURL, nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "network timeout") { + t.Fatalf("unexpected error: %v", err) + } + if listCount < 1 { + t.Fatalf("want pre-add offline list call, got %d", listCount) + } + }) +} + +func TestOpen115BasicMethods(t *testing.T) { + t.Parallel() + + o := &Open115{} + + if o.Name() != "115 Open" { + t.Fatalf("unexpected name: %s", o.Name()) + } + if o.Items() != nil { + t.Fatal("Items should return nil") + } + msg, err := o.Init() + if err != nil { + t.Fatalf("unexpected init error: %v", err) + } + if msg != "ok" { + t.Fatalf("unexpected init message: %s", msg) + } +} diff --git a/internal/offline_download/aria2/aria2.go b/internal/offline_download/aria2/aria2.go index b04435aca..5c037ff4c 100644 --- a/internal/offline_download/aria2/aria2.go +++ b/internal/offline_download/aria2/aria2.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strconv" + "strings" "time" "github.com/OpenListTeam/OpenList/v4/internal/errs" @@ -61,6 +62,10 @@ func (a *Aria2) IsReady() bool { } func (a *Aria2) AddURL(args *tool.AddUrlArgs) (string, error) { + // aria2 不支持 ed2k 协议,提前检测并返回明确错误 + if strings.HasPrefix(strings.ToLower(args.Url), "ed2k://") { + return "", fmt.Errorf("aria2 does not support ed2k protocol. Please use Thunder/ThunderX/ThunderBrowser tool for ed2k links") + } options := map[string]interface{}{ "dir": args.TempDir, } diff --git a/internal/offline_download/pikpak/pikpak.go b/internal/offline_download/pikpak/pikpak.go index f48ed9951..bcd44720c 100644 --- a/internal/offline_download/pikpak/pikpak.go +++ b/internal/offline_download/pikpak/pikpak.go @@ -129,7 +129,7 @@ func (p *PikPak) Status(task *tool.DownloadTask) (*tool.Status, error) { s.TotalBytes = 0 } if t.Phase == "PHASE_TYPE_ERROR" { - s.Err = fmt.Errorf(t.Message) + s.Err = fmt.Errorf("%s", t.Message) } return s, nil } diff --git a/internal/offline_download/tool/add.go b/internal/offline_download/tool/add.go index 0f574571e..33128ccc2 100644 --- a/internal/offline_download/tool/add.go +++ b/internal/offline_download/tool/add.go @@ -2,9 +2,11 @@ package tool import ( "context" + "fmt" "net/url" stdpath "path" "path/filepath" + "strings" _115 "github.com/OpenListTeam/OpenList/v4/drivers/115" _115_open "github.com/OpenListTeam/OpenList/v4/drivers/115_open" @@ -67,10 +69,28 @@ func AddURL(ctx context.Context, args *AddURLArgs) (task.TaskExtensionInfo, erro } // try putting url if args.Tool == "SimpleHttp" { + if isSimpleHttpSchemeUnsupported(args.URL) { + return nil, fmt.Errorf("SimpleHttp tool does not support this URL scheme, please use aria2 or other tools for magnet/ed2k links") + } err = tryPutUrl(ctx, args.DstDirPath, args.URL) if err == nil || !errors.Is(err, errs.NotImplement) { return nil, err } + // Fallback to creating a download task when storage lacks native PutURL support. + } + + // ed2k 链接自动路由:如果当前工具不支持 ed2k,自动尝试使用迅雷系工具 + if isEd2kURL(args.URL) { + if !isEd2kCapableTool(args.Tool) { + // 尝试找到一个可用的支持 ed2k 的工具 + fallbackTool, fallbackName := findEd2kCapableTool() + if fallbackTool != nil { + // 使用找到的迅雷工具替代 + args.Tool = fallbackName + } else { + return nil, fmt.Errorf("ed2k protocol is not supported by %s. Please configure and use Thunder/ThunderX/ThunderBrowser for ed2k links", args.Tool) + } + } } // get tool @@ -171,3 +191,48 @@ func tryPutUrl(ctx context.Context, path, urlStr string) error { } return fs.PutURL(ctx, path, dstName, urlStr) } + +func isSimpleHttpSchemeUnsupported(urlStr string) bool { + u, err := url.Parse(strings.TrimSpace(urlStr)) + if err != nil || u.Scheme == "" { + return false + } + scheme := strings.ToLower(u.Scheme) + return scheme != "http" && scheme != "https" +} + +// isEd2kURL 检测 URL 是否为 ed2k 协议 +func isEd2kURL(urlStr string) bool { + return strings.HasPrefix(strings.ToLower(urlStr), "ed2k://") +} + +// ed2kCapableTools 支持 ed2k 协议的工具列表(迅雷系) +var ed2kCapableTools = []string{"Thunder", "ThunderX", "ThunderBrowser"} + +// isEd2kCapableTool 检查工具是否支持 ed2k 协议 +func isEd2kCapableTool(toolName string) bool { + for _, t := range ed2kCapableTools { + if t == toolName { + return true + } + } + return false +} + +// findEd2kCapableTool 查找一个可用的支持 ed2k 的工具 +func findEd2kCapableTool() (Tool, string) { + for _, name := range ed2kCapableTools { + t, err := Tools.Get(name) + if err != nil { + continue + } + if t.IsReady() { + return t, name + } + // 尝试初始化 + if _, err := t.Init(); err == nil && t.IsReady() { + return t, name + } + } + return nil, "" +} diff --git a/internal/offline_download/tool/download.go b/internal/offline_download/tool/download.go index 5ee6ef4ff..477c26a51 100644 --- a/internal/offline_download/tool/download.go +++ b/internal/offline_download/tool/download.go @@ -32,6 +32,8 @@ type DownloadTask struct { callStatusRetried int } +var completedOfflineTaskCleanupDelay = time.Second + func (t *DownloadTask) Run() error { t.ClearEndTime() t.SetStartTime(time.Now()) @@ -97,16 +99,7 @@ outer: if t.tool.Name() == "ThunderX" { return nil } - if t.tool.Name() == "115 Cloud" { - // hack for 115 - <-time.After(time.Second * 1) - err := t.tool.Remove(t) - if err != nil { - log.Errorln(err.Error()) - } - return nil - } - if t.tool.Name() == "115 Open" { + if t.tool.Name() == "115 Cloud" || t.tool.Name() == "115 Open" { return nil } if t.tool.Name() == "123 Open" { @@ -147,7 +140,7 @@ func (t *DownloadTask) Update() (bool, error) { if err != nil { t.callStatusRetried++ log.Errorf("failed to get status of %s, retried %d times", t.ID, t.callStatusRetried) - if t.callStatusRetried > 5 { + if t.callStatusRetried > 10 { return true, errors.Errorf("failed to get status of %s, retried %d times", t.ID, t.callStatusRetried) } return false, nil @@ -163,6 +156,13 @@ func (t *DownloadTask) Update() (bool, error) { } // if download completed if info.Completed { + // For 115, remove offline task record before transfer so it gets cleaned up even if transfer fails + if t.tool.Name() == "115 Cloud" || t.tool.Name() == "115 Open" { + <-time.After(completedOfflineTaskCleanupDelay) + if removeErr := t.tool.Remove(t); removeErr != nil { + log.Errorln(removeErr.Error()) + } + } err := t.Transfer() return true, errors.WithMessage(err, "failed to transfer file") } diff --git a/internal/offline_download/tool/download_test.go b/internal/offline_download/tool/download_test.go new file mode 100644 index 000000000..5303d35aa --- /dev/null +++ b/internal/offline_download/tool/download_test.go @@ -0,0 +1,90 @@ +package tool + +import ( + "context" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + task2 "github.com/OpenListTeam/OpenList/v4/internal/task" +) + +type mockTool struct { + name string + addURLFunc func(args *AddUrlArgs) (string, error) + removeFunc func(task *DownloadTask) error + statusFunc func(task *DownloadTask) (*Status, error) + runFunc func(task *DownloadTask) error +} + +func (m *mockTool) Name() string { return m.name } + +func (m *mockTool) Items() []model.SettingItem { return nil } + +func (m *mockTool) Init() (string, error) { return "ok", nil } + +func (m *mockTool) IsReady() bool { return true } + +func (m *mockTool) AddURL(args *AddUrlArgs) (string, error) { + return m.addURLFunc(args) +} + +func (m *mockTool) Remove(task *DownloadTask) error { + return m.removeFunc(task) +} + +func (m *mockTool) Status(task *DownloadTask) (*Status, error) { + return m.statusFunc(task) +} + +func (m *mockTool) Run(task *DownloadTask) error { + return m.runFunc(task) +} + +func TestDownloadTaskRun_RemovesCompleted115OpenRecord(t *testing.T) { + previousDelay := completedOfflineTaskCleanupDelay + completedOfflineTaskCleanupDelay = 0 + defer func() { + completedOfflineTaskCleanupDelay = previousDelay + }() + + removeCount := 0 + tool := &mockTool{ + name: "115 Open", + addURLFunc: func(args *AddUrlArgs) (string, error) { + return "gid-1", nil + }, + removeFunc: func(task *DownloadTask) error { + removeCount++ + if task.GID != "gid-1" { + t.Fatalf("unexpected gid: %s", task.GID) + } + return nil + }, + statusFunc: func(task *DownloadTask) (*Status, error) { + return &Status{ + Completed: true, + Status: "completed", + }, nil + }, + runFunc: func(task *DownloadTask) error { + return errs.NotSupport + }, + } + + task := &DownloadTask{ + TaskExtension: task2.TaskExtension{}, + Url: "https://example.com/test.torrent", + DstDirPath: "/115", + TempDir: "/115", + tool: tool, + } + task.SetCtx(context.Background()) + + if err := task.Run(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if removeCount != 1 { + t.Fatalf("want 1 cleanup remove, got %d", removeCount) + } +} diff --git a/internal/offline_download/tool/transfer.go b/internal/offline_download/tool/transfer.go index fd6b8f464..7109669ee 100644 --- a/internal/offline_download/tool/transfer.go +++ b/internal/offline_download/tool/transfer.go @@ -9,6 +9,7 @@ import ( "path/filepath" "time" + _189pc "github.com/OpenListTeam/OpenList/v4/drivers/189pc" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/fs" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -17,6 +18,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/task" "github.com/OpenListTeam/OpenList/v4/internal/task_group" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/OpenListTeam/OpenList/v4/pkg/torrent" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/OpenListTeam/tache" @@ -201,8 +203,28 @@ func transferStdFile(t *TransferTask) error { } info, err := rc.Stat() if err != nil { + rc.Close() return errors.Wrapf(err, "failed to get file %s", t.SrcActualPath) } + + // 尝试对天翼云进行秒传(计算 MD5 + sliceMD5) + if rapidObj, rapidErr := tryRapidUpload189(t, rc, info.Size()); rapidErr == nil && rapidObj != nil { + rc.Close() + log.Infof("秒传成功: %s -> %s", t.SrcActualPath, t.DstStorageMp) + return nil + } + + // 秒传失败或不支持,回退到普通上传 + // 重新 seek 到文件开头 + if _, err := rc.Seek(0, 0); err != nil { + rc.Close() + // 重新打开文件 + rc, err = os.Open(t.SrcActualPath) + if err != nil { + return errors.Wrapf(err, "failed to reopen file %s", t.SrcActualPath) + } + } + mimetype := utils.GetMimeType(t.SrcActualPath) s := &stream.FileStream{ Ctx: t.Ctx(), @@ -217,7 +239,12 @@ func transferStdFile(t *TransferTask) error { Closers: utils.NewClosers(rc), } t.SetTotalBytes(info.Size()) - return op.Put(context.WithValue(t.Ctx(), conf.SkipHookKey, struct{}{}), t.DstStorage, t.DstActualPath, s, t.SetProgress) + err = op.Put(context.WithValue(t.Ctx(), conf.SkipHookKey, struct{}{}), t.DstStorage, t.DstActualPath, s, t.SetProgress) + if err != nil { + return err + } + + return nil } func removeStdTemp(t *TransferTask) { @@ -341,3 +368,47 @@ func removeObjTemp(t *TransferTask) { log.Errorf("failed to delete temp obj %s, error: %s", t.SrcActualPath, err.Error()) } } + +// tryRapidUpload189 尝试对天翼云进行秒传 +// 通过计算文件的 MD5 来尝试秒传(使用旧版接口) +// 返回上传成功的对象和错误,如果不支持秒传则返回 nil, error +func tryRapidUpload189(t *TransferTask, file *os.File, fileSize int64) (model.Obj, error) { + // 检查目标存储是否是天翼云 PC 驱动 + cloud189PC, ok := t.DstStorage.(*_189pc.Cloud189PC) + if !ok { + return nil, fmt.Errorf("not 189pc storage") + } + + // 计算整文件 MD5(旧接口只需要 fileMD5) + fileMD5, _, err := _189pc.ComputeSliceMD5sFromReader(file, torrent.DefaultPieceSize) + if err != nil { + return nil, fmt.Errorf("计算 MD5 失败: %w", err) + } + + // 获取目标目录 + dstDir, err := op.Get(t.Ctx(), t.DstStorage, t.DstActualPath) + if err != nil { + return nil, fmt.Errorf("获取目标目录失败: %w", err) + } + + // 构造文件名 + fileName := filepath.Base(t.SrcActualPath) + + // 尝试秒传(使用旧接口) + uploadInfo, err := cloud189PC.OldUploadCreate(t.Ctx(), dstDir.GetID(), fileMD5, fileName, fmt.Sprint(fileSize), false) + if err != nil { + return nil, fmt.Errorf("创建上传任务失败: %w", err) + } + + if uploadInfo.FileDataExists != 1 { + return nil, fmt.Errorf("秒传失败:云端不存在该文件") + } + + // 秒传成功,提交 + obj, err := cloud189PC.OldUploadCommit(t.Ctx(), uploadInfo.FileCommitUrl, uploadInfo.UploadFileId, false, true) + if err != nil { + return nil, fmt.Errorf("提交上传失败: %w", err) + } + + return obj, nil +} diff --git a/internal/op/fs.go b/internal/op/fs.go index f7b1c45b5..09a26dcfe 100644 --- a/internal/op/fs.go +++ b/internal/op/fs.go @@ -149,6 +149,7 @@ func Get(ctx context.Context, storage driver.Driver, path string, excludeTempObj Modified: storage.GetStorage().Modified, IsFolder: true, Mask: model.Locked, + HashInfo: utils.NewHashInfo(nil, ""), }, nil case driver.IRootPath: return &model.Object{ @@ -157,6 +158,7 @@ func Get(ctx context.Context, storage driver.Driver, path string, excludeTempObj Modified: storage.GetStorage().Modified, Mask: model.Locked, IsFolder: true, + HashInfo: utils.NewHashInfo(nil, ""), }, nil } return nil, errors.New("please implement GetRooter or IRootPath or IRootId interface") @@ -246,6 +248,8 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li ol.link.SyncClosers.AcquireReference() || !ol.link.RequireReference { return ol.link, ol.obj, nil } + // SyncClosers 已关闭(文件句柄已关闭),删除缓存条目,重新获取 + Cache.linkCache.DeleteKey(key) } fn := func() (*objWithLink, error) { @@ -261,11 +265,37 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li if err != nil { return nil, errors.Wrapf(err, "failed get link") } + + // Set up link refresher for automatic refresh on expiry during long downloads + // This enables all download scenarios to handle link expiration gracefully + if link.Refresher == nil { + storageCopy := storage + pathCopy := path + argsCopy := args + link.Refresher = func(refreshCtx context.Context) (*model.Link, model.Obj, error) { + log.Infof("Refreshing download link for: %s", pathCopy) + // Get fresh link directly from storage, bypassing cache + file, err := GetUnwrap(refreshCtx, storageCopy, pathCopy) + if err != nil { + return nil, nil, errors.WithMessage(err, "failed to get file for refresh") + } + newLink, err := storageCopy.Link(refreshCtx, file, argsCopy) + if err != nil { + return nil, nil, errors.Wrapf(err, "failed to refresh link") + } + return newLink, file, nil + } + } + ol := &objWithLink{link: link, obj: file} if link.Expiration != nil { Cache.linkCache.SetTypeWithTTL(key, typeKey, ol, *link.Expiration) - } else { + } else if link.RequireReference { + // 本地文件等需要引用计数的链接,缓存与文件句柄生命周期绑定 Cache.linkCache.SetTypeWithExpirable(key, typeKey, ol, &link.SyncClosers) + } else { + // 不需要引用计数(如云盘链接无过期时间),使用默认 TTL,多客户端复用 + Cache.linkCache.SetType(key, typeKey, ol) } return ol, nil } @@ -325,9 +355,16 @@ func MakeDir(ctx context.Context, storage driver.Driver, path string) error { return nil, errors.WithMessagef(err, "failed to make parent dir [%s]", parentPath) } parentDir, err := GetUnwrap(ctx, storage, parentPath) - // this should not happen if err != nil { - return nil, errors.WithMessagef(err, "failed to get parent dir [%s]", parentPath) + if errs.IsObjectNotFound(err) { + // Retry once after a short delay (handles cloud storage API sync delay) + log.Debugf("[op] parent dir [%s] not found immediately after creation, retrying...", parentPath) + time.Sleep(100 * time.Millisecond) + parentDir, err = GetUnwrap(ctx, storage, parentPath) + } + if err != nil { + return nil, errors.WithMessagef(err, "failed to get parent dir [%s]", parentPath) + } } if !parentDir.IsDir() { return nil, errs.NotFolder @@ -360,6 +397,7 @@ func MakeDir(ctx context.Context, storage driver.Driver, path string) error { Modified: t, Ctime: t, Mask: model.Temp, + HashInfo: utils.NewHashInfo(nil, ""), } } dirCache.UpdateObject("", wrapObjName(storage, newObj)) @@ -684,6 +722,7 @@ func Put(ctx context.Context, storage driver.Driver, dstDirPath string, file mod Modified: file.ModTime(), Ctime: file.CreateTime(), Mask: model.Temp, + HashInfo: utils.NewHashInfo(nil, ""), } } newObj = wrapObjName(storage, newObj) @@ -752,6 +791,7 @@ func PutURL(ctx context.Context, storage driver.Driver, dstDirPath, dstName, url Modified: t, Ctime: t, Mask: model.Temp, + HashInfo: utils.NewHashInfo(nil, ""), } } newObj = wrapObjName(storage, newObj) diff --git a/internal/op/storage.go b/internal/op/storage.go index da4c84e31..2e93bf569 100644 --- a/internal/op/storage.go +++ b/internal/op/storage.go @@ -368,7 +368,9 @@ func GetStorageVirtualFilesWithDetailsByPath(ctx context.Context, prefix string, }(d) select { case r := <-resultChan: - ret.StorageDetails = r + if r != nil { + ret.StorageDetails = r + } case <-time.After(time.Second): } return ret @@ -419,6 +421,7 @@ func getStorageVirtualFilesByPath(prefix string, rootCallback func(driver.Driver Name: name, Modified: v.GetStorage().Modified, IsFolder: true, + HashInfo: utils.NewHashInfo(nil, ""), } if !found { idx := len(files) diff --git a/internal/stream/section_reader_prefetch_test.go b/internal/stream/section_reader_prefetch_test.go new file mode 100644 index 000000000..57e14db66 --- /dev/null +++ b/internal/stream/section_reader_prefetch_test.go @@ -0,0 +1,399 @@ +package stream + +import ( + "bytes" + "context" + "crypto/rand" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" +) + +// signalReader is an io.Reader that records the high-water mark of bytes +// read so tests can deterministically observe how much of the source has +// been consumed at a given moment. It also exposes an optional per-Read +// delay so the test can create a window in which prefetch can run. +type signalReader struct { + data []byte + readPos atomic.Int64 + delay time.Duration + readCnt atomic.Int32 + chunkSig chan int64 // optional: receives high-water mark after each Read +} + +func newSignalReader(data []byte, delay time.Duration) *signalReader { + return &signalReader{data: data, delay: delay, chunkSig: make(chan int64, 1024)} +} + +func (r *signalReader) Read(p []byte) (int, error) { + pos := r.readPos.Load() + if pos >= int64(len(r.data)) { + return 0, io.EOF + } + if r.delay > 0 { + time.Sleep(r.delay) + } + n := copy(p, r.data[pos:]) + newPos := pos + int64(n) + r.readPos.Store(newPos) + r.readCnt.Add(1) + select { + case r.chunkSig <- newPos: + default: + } + return n, nil +} + +// fakeFileStream wraps a signalReader into a FileStreamer-compatible value +// reusing the production FileStream type. +func newFakeFileStream(t *testing.T, data []byte, delay time.Duration) (*FileStream, *signalReader) { + t.Helper() + sr := newSignalReader(data, delay) + fs := &FileStream{ + Ctx: context.Background(), + Obj: &model.Object{Name: "test.bin", Size: int64(len(data))}, + Reader: io.NopCloser(sr), + } + return fs, sr +} + +// withStreamConf sets minimal stream/conf values needed for the +// hybridSectionReader path and restores them afterwards. +func withStreamConf(t *testing.T, cacheThreshold, maxBlock uint64) { + t.Helper() + prevCT := conf.AutoMemoryLimit + prevMB := conf.MaxBlockLimit + prevMF := conf.MinFreeMemory + prevConf := conf.Conf + t.Cleanup(func() { + conf.AutoMemoryLimit = prevCT + conf.MaxBlockLimit = prevMB + conf.MinFreeMemory = prevMF + conf.Conf = prevConf + }) + conf.AutoMemoryLimit = cacheThreshold + conf.MaxBlockLimit = maxBlock + conf.MinFreeMemory = 1 // keep memory path enabled + conf.Conf = &conf.Config{} +} + +// TestHybridSectionReader_PrefetchAdvancesSourceAhead verifies that after +// GetSectionReader returns block N, the underlying source has been read +// past the end of block N (i.e., prefetch of block N+1 is in flight or +// already complete). This is the core behavior of Pass 2 prefetch. +func TestHybridSectionReader_PrefetchAdvancesSourceAhead(t *testing.T) { + withStreamConf(t, 1024, 4*1024) + + const partSize = int64(4 * 1024) + data := make([]byte, partSize*4) + _, _ = rand.Read(data) + + // Per-Read delay so each chunk takes measurable time; this gives + // prefetch a window to advance during the "upload" phase below. + fs, sr := newFakeFileStream(t, data, 5*time.Millisecond) + + ss, err := NewStreamSectionReader(fs, int(partSize), nil) + if err != nil { + t.Fatalf("NewStreamSectionReader: %v", err) + } + hsr, ok := ss.(*hybridSectionReader) + if !ok { + t.Fatalf("expected *hybridSectionReader, got %T", ss) + } + _ = hsr + + // Read block 0 + rs, err := ss.GetSectionReader(0, partSize) + if err != nil { + t.Fatalf("GetSectionReader(0): %v", err) + } + got, err := io.ReadAll(rs) + if err != nil { + t.Fatalf("ReadAll block0: %v", err) + } + if !bytes.Equal(got, data[:partSize]) { + t.Fatalf("block0 data mismatch") + } + + // Simulate "upload" time; prefetch should pull block 1 from source. + deadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + if sr.readPos.Load() > partSize { + break + } + time.Sleep(2 * time.Millisecond) + } + + if got := sr.readPos.Load(); got <= partSize { + t.Fatalf("expected prefetch to advance source past %d, got %d", partSize, got) + } + + ss.FreeSectionReader(rs) + + // Read block 1 — should be served from prefetch (data must still be correct). + rs2, err := ss.GetSectionReader(partSize, partSize) + if err != nil { + t.Fatalf("GetSectionReader(1): %v", err) + } + got2, err := io.ReadAll(rs2) + if err != nil { + t.Fatalf("ReadAll block1: %v", err) + } + if !bytes.Equal(got2, data[partSize:2*partSize]) { + t.Fatalf("block1 data mismatch") + } + ss.FreeSectionReader(rs2) +} + +// TestHybridSectionReader_PrefetchSequentialCorrectness verifies that a +// full sequential walk through the file returns the exact original bytes +// when prefetch is enabled. +func TestHybridSectionReader_PrefetchSequentialCorrectness(t *testing.T) { + withStreamConf(t, 1024, 4*1024) + + const partSize = int64(4 * 1024) + const partCount = 6 + data := make([]byte, partSize*partCount) + _, _ = rand.Read(data) + + fs, _ := newFakeFileStream(t, data, 0) + ss, err := NewStreamSectionReader(fs, int(partSize), nil) + if err != nil { + t.Fatalf("NewStreamSectionReader: %v", err) + } + + for i := int64(0); i < partCount; i++ { + off := i * partSize + rs, err := ss.GetSectionReader(off, partSize) + if err != nil { + t.Fatalf("GetSectionReader(%d): %v", i, err) + } + got, err := io.ReadAll(rs) + if err != nil { + t.Fatalf("ReadAll part %d: %v", i, err) + } + if !bytes.Equal(got, data[off:off+partSize]) { + t.Fatalf("part %d data mismatch", i) + } + ss.FreeSectionReader(rs) + } +} + +// TestHybridSectionReader_PrefetchLastChunkPartial covers the final +// partial chunk: when the file isn't a multiple of partSize, the last +// GetSectionReader call asks for a smaller length than the previous +// (prefetched) call expected. The data returned must still be correct. +func TestHybridSectionReader_PrefetchLastChunkPartial(t *testing.T) { + withStreamConf(t, 1024, 4*1024) + + const partSize = int64(4 * 1024) + // 2.5 parts: last chunk is partSize/2 + data := make([]byte, partSize*2+partSize/2) + _, _ = rand.Read(data) + + fs, _ := newFakeFileStream(t, data, 0) + ss, err := NewStreamSectionReader(fs, int(partSize), nil) + if err != nil { + t.Fatalf("NewStreamSectionReader: %v", err) + } + + offsets := []struct{ off, length int64 }{ + {0, partSize}, + {partSize, partSize}, + {partSize * 2, partSize / 2}, + } + for i, p := range offsets { + rs, err := ss.GetSectionReader(p.off, p.length) + if err != nil { + t.Fatalf("GetSectionReader(%d, %d): %v", p.off, p.length, err) + } + got, err := io.ReadAll(rs) + if err != nil { + t.Fatalf("ReadAll part %d: %v", i, err) + } + if !bytes.Equal(got, data[p.off:p.off+p.length]) { + t.Fatalf("part %d data mismatch", i) + } + ss.FreeSectionReader(rs) + } +} + +// TestHybridSectionReader_PrefetchErrorSurfaces verifies that a read +// error encountered during prefetch is reported on the next +// GetSectionReader call (rather than silently ignored). +func TestHybridSectionReader_PrefetchErrorSurfaces(t *testing.T) { + withStreamConf(t, 1024, 4*1024) + + const partSize = int64(4 * 1024) + data := make([]byte, partSize*3) + _, _ = rand.Read(data) + + // Wrap the signal reader so reads targeting the second chunk return + // a hard error and zero bytes — simulating mid-chunk download failure. + sr := newSignalReader(data, 0) + fs := &FileStream{ + Ctx: context.Background(), + Obj: &model.Object{Name: "test.bin", Size: int64(len(data))}, + Reader: io.NopCloser(readerFunc(func(p []byte) (int, error) { + if sr.readPos.Load() >= partSize { + return 0, io.ErrUnexpectedEOF + } + remaining := partSize - sr.readPos.Load() + if int64(len(p)) > remaining { + p = p[:remaining] + } + return sr.Read(p) + })), + } + + ss, err := NewStreamSectionReader(fs, int(partSize), nil) + if err != nil { + t.Fatalf("NewStreamSectionReader: %v", err) + } + + rs, err := ss.GetSectionReader(0, partSize) + if err != nil { + t.Fatalf("GetSectionReader(0): %v", err) + } + _, _ = io.ReadAll(rs) + ss.FreeSectionReader(rs) + + // Wait for prefetch to complete (and presumably fail). + time.Sleep(20 * time.Millisecond) + + _, err = ss.GetSectionReader(partSize, partSize) + if err == nil { + t.Fatalf("expected error from prefetch failure, got nil") + } +} + +// TestHybridSectionReader_NoPrefetchAfterEOF makes sure we don't start a +// prefetch goroutine past the end of the file. +func TestHybridSectionReader_NoPrefetchAfterEOF(t *testing.T) { + withStreamConf(t, 1024, 4*1024) + + const partSize = int64(4 * 1024) + data := make([]byte, partSize*2) + _, _ = rand.Read(data) + + fs, sr := newFakeFileStream(t, data, 0) + ss, err := NewStreamSectionReader(fs, int(partSize), nil) + if err != nil { + t.Fatalf("NewStreamSectionReader: %v", err) + } + + for i := int64(0); i < 2; i++ { + rs, err := ss.GetSectionReader(i*partSize, partSize) + if err != nil { + t.Fatalf("GetSectionReader(%d): %v", i, err) + } + _, _ = io.ReadAll(rs) + ss.FreeSectionReader(rs) + } + + // Allow any (incorrect) prefetch to fire. + time.Sleep(20 * time.Millisecond) + + // Source must not have been read past end of file. + if got := sr.readPos.Load(); got != int64(len(data)) { + t.Fatalf("source read pos = %d, want %d (no over-read past EOF)", got, len(data)) + } + + hsr := ss.(*hybridSectionReader) + if hsr.prefetch != nil { + t.Fatalf("expected no in-flight prefetch after final chunk") + } +} + +// TestHybridSectionReader_PrefetchClampedToFileSize ensures the prefetch +// length is clamped to the remaining bytes so the goroutine doesn't +// request more than the file holds. +func TestHybridSectionReader_PrefetchClampedToFileSize(t *testing.T) { + withStreamConf(t, 1024, 4*1024) + + const partSize = int64(4 * 1024) + // 1.5 parts so prefetch after block 0 must be clamped to partSize/2. + data := make([]byte, partSize+partSize/2) + _, _ = rand.Read(data) + + fs, _ := newFakeFileStream(t, data, 0) + ss, err := NewStreamSectionReader(fs, int(partSize), nil) + if err != nil { + t.Fatalf("NewStreamSectionReader: %v", err) + } + + rs, err := ss.GetSectionReader(0, partSize) + if err != nil { + t.Fatalf("GetSectionReader(0): %v", err) + } + _, _ = io.ReadAll(rs) + ss.FreeSectionReader(rs) + + rs2, err := ss.GetSectionReader(partSize, partSize/2) + if err != nil { + t.Fatalf("GetSectionReader(1): %v", err) + } + got, err := io.ReadAll(rs2) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(got, data[partSize:]) { + t.Fatalf("partial-tail data mismatch") + } + ss.FreeSectionReader(rs2) +} + +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { return f(p) } + +// TestHybridSectionReader_PrefetchHonorsDiscardSection verifies that +// DiscardSection invalidates any pending prefetch and that subsequent +// GetSectionReader calls at the new offset return correct data. +func TestHybridSectionReader_PrefetchHonorsDiscardSection(t *testing.T) { + withStreamConf(t, 1024, 4*1024) + + const partSize = int64(4 * 1024) + data := make([]byte, partSize*4) + _, _ = rand.Read(data) + + fs, _ := newFakeFileStream(t, data, 0) + ss, err := NewStreamSectionReader(fs, int(partSize), nil) + if err != nil { + t.Fatalf("NewStreamSectionReader: %v", err) + } + + // Read block 0 (kicks off prefetch of block 1) + rs, err := ss.GetSectionReader(0, partSize) + if err != nil { + t.Fatalf("GetSectionReader: %v", err) + } + _, _ = io.ReadAll(rs) + ss.FreeSectionReader(rs) + + // Give prefetch a moment to complete + time.Sleep(20 * time.Millisecond) + + // Caller decides to skip block 1 — calls DiscardSection + if err := ss.DiscardSection(partSize, partSize); err != nil { + t.Fatalf("DiscardSection: %v", err) + } + + // Now read block 2 — must return correct data + rs2, err := ss.GetSectionReader(2*partSize, partSize) + if err != nil { + t.Fatalf("GetSectionReader after discard: %v", err) + } + got, err := io.ReadAll(rs2) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(got, data[2*partSize:3*partSize]) { + t.Fatalf("block 2 data mismatch after discard") + } + ss.FreeSectionReader(rs2) +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 4c8238100..d7a0936ac 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -1,6 +1,7 @@ package stream import ( + "bytes" "context" "errors" "fmt" @@ -10,11 +11,11 @@ import ( "sync" "github.com/OpenListTeam/OpenList/v4/internal/conf" + hcache "github.com/OpenListTeam/OpenList/v4/internal/hybrid_cache" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/pkg/buffer" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/rclone/rclone/lib/mmap" "go4.org/readerutil" ) @@ -28,8 +29,9 @@ type FileStream struct { Exist model.Obj //the file existed in the destination, we can reuse some info since we wil overwrite it utils.Closers size int64 - peekBuff *buffer.Reader oriReader io.Reader // the original reader, used for caching + hc *hcache.HybridCache + peek buffer.SizedReadAtSeeker } func (f *FileStream) GetSize() int64 { @@ -51,15 +53,6 @@ func (f *FileStream) IsForceStreamUpload() bool { return f.ForceStreamUpload } -func (f *FileStream) Close() error { - if f.peekBuff != nil { - f.peekBuff.Reset() - f.oriReader = nil - f.peekBuff = nil - } - return f.Closers.Close() -} - func (f *FileStream) GetExist() model.Obj { return f.Exist } @@ -101,79 +94,57 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ } reader := f.Reader - if f.peekBuff != nil { - f.peekBuff.Seek(0, io.SeekStart) + if f.peek != nil { + f.peek.Seek(0, io.SeekStart) if writer != nil { - _, err := utils.CopyWithBuffer(writer, f.peekBuff) + _, err := utils.CopyWithBuffer(writer, f.peek) if err != nil { return nil, err } - f.peekBuff.Seek(0, io.SeekStart) + f.peek.Seek(0, io.SeekStart) } reader = f.oriReader } if writer != nil { reader = io.TeeReader(reader, writer) } + + // 如果文件大小未知,直接缓存到磁盘 if f.GetSize() < 0 { - if f.peekBuff == nil { - f.peekBuff = &buffer.Reader{} - } // 检查是否有数据 buf := []byte{0} n, err := io.ReadFull(reader, buf) - if n > 0 { - f.peekBuff.Append(buf[:n]) - } - if err == io.ErrUnexpectedEOF { - f.size = f.peekBuff.Size() - f.Reader = f.peekBuff - return f.peekBuff, nil + br := bytes.NewReader(buf[:n]) + if err == io.ErrUnexpectedEOF || err == io.EOF { + f.size = br.Size() + f.Reader = br + return br, nil } else if err != nil { return nil, err } - if conf.MaxBufferLimit-n > conf.MmapThreshold && conf.MmapThreshold > 0 { - m, err := mmap.Alloc(conf.MaxBufferLimit - n) - if err == nil { - f.Add(utils.CloseFunc(func() error { - return mmap.Free(m) - })) - n, err = io.ReadFull(reader, m) - if n > 0 { - f.peekBuff.Append(m[:n]) - } - if err == io.ErrUnexpectedEOF { - f.size = f.peekBuff.Size() - f.Reader = f.peekBuff - return f.peekBuff, nil - } else if err != nil { - return nil, err - } - } - } - tmpF, err := utils.CreateTempFile(reader, 0) + tmpF, err := utils.CreateTempFile(io.MultiReader(br, reader), 0) if err != nil { return nil, err } f.Add(utils.CloseFunc(func() error { return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name())) })) - peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF) + stat, err := tmpF.Stat() if err != nil { return nil, err } - f.size = peekF.Size() - f.Reader = peekF - return peekF, nil + f.size = stat.Size() + f.Reader = tmpF + return tmpF, nil } if up != nil { cacheProgress := model.UpdateProgressWithRange(*up, 0, 50) *up = model.UpdateProgressWithRange(*up, 50, 100) size := f.GetSize() - if f.peekBuff != nil { - peekSize := f.peekBuff.Size() - cacheProgress(float64(peekSize) / float64(size) * 100) + if f.peek != nil { + peekSize := f.peek.Size() + // cacheProgress(float64(peekSize) / float64(size) * 100) size -= peekSize } reader = &ReaderUpdatingProgress{ @@ -185,12 +156,12 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ } } - if f.peekBuff != nil { + if f.oriReader != nil { f.oriReader = reader } else { f.Reader = reader } - return f.cache(f.GetSize()) + return f.ensureCache(f.GetSize()) } func (f *FileStream) GetFile() model.File { @@ -211,7 +182,7 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { return io.NewSectionReader(f.GetFile(), httpRange.Start, httpRange.Length), nil } - cache, err := f.cache(httpRange.Start + httpRange.Length) + cache, err := f.ensureCache(httpRange.Start + httpRange.Length) if err != nil { return nil, err } @@ -224,64 +195,32 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { // 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大 // 确保指定大小的数据被缓存 -func (f *FileStream) cache(maxCacheSize int64) (model.File, error) { - if maxCacheSize > int64(conf.MaxBufferLimit) { - size := f.GetSize() - reader := f.Reader - if f.peekBuff != nil { - size -= f.peekBuff.Size() - reader = f.oriReader - } - tmpF, err := utils.CreateTempFile(reader, size) +func (f *FileStream) ensureCache(size int64) (model.File, error) { + if f.peek == nil { + blockSize := min(size, f.GetSize(), int64(conf.MaxBlockLimit)) + var err error + f.hc, err = hcache.NewHybridCache(uint64(blockSize), uint64(f.GetSize())) if err != nil { return nil, err } - f.Add(utils.CloseFunc(func() error { - return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name())) - })) - if f.peekBuff != nil { - peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF) - if err != nil { - return nil, err - } - f.Reader = peekF - return peekF, nil - } - f.Reader = tmpF - return tmpF, nil - } - - if f.peekBuff == nil { - f.peekBuff = &buffer.Reader{} + f.peek = buffer.NewDynamicReadAtSeeker(f.hc) f.oriReader = f.Reader - f.Reader = io.MultiReader(f.peekBuff, f.oriReader) + f.Reader = io.MultiReader(f.peek, f.oriReader) + f.Add(f.hc) } - bufSize := maxCacheSize - f.peekBuff.Size() - if bufSize <= 0 { - return f.peekBuff, nil + size = size - f.peek.Size() + if size <= 0 { + return f.peek, nil } - var buf []byte - if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) { - m, err := mmap.Alloc(int(bufSize)) - if err == nil { - f.Add(utils.CloseFunc(func() error { - return mmap.Free(m) - })) - buf = m - } - } - if buf == nil { - buf = make([]byte, bufSize) - } - n, err := io.ReadFull(f.oriReader, buf) - if bufSize != int64(n) { - return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err) + written, err := f.hc.CopyFromN(f.oriReader, size) + if written != size { + f.hc.RewindBySize(uint64(size - written)) + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", size, written, err) } - f.peekBuff.Append(buf) - if f.peekBuff.Size() >= f.GetSize() { - f.Reader = f.peekBuff + if f.peek.Size() >= f.GetSize() { + f.Reader = f.peek } - return f.peekBuff, nil + return f.peek, nil } var _ model.FileStreamer = (*SeekableStream)(nil) @@ -315,15 +254,9 @@ func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error if err != nil { return nil, err } - if _, ok := rr.(*model.FileRangeReader); ok { - var rc io.ReadCloser - rc, err = rr.RangeRead(fs.Ctx, http_range.Range{Length: -1}) - if err != nil { - return nil, err - } - fs.Reader = rc - fs.Add(rc) - } + // IMPORTANT: Do NOT create Reader early for FileRangeReader! + // Let generateReader() create it on-demand when actually needed for reading + // This prevents the Reader from being consumed by intermediate operations like hash calculation fs.size = size fs.Add(link) return &SeekableStream{FileStream: fs, rangeReader: rr}, nil diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index 9a81e7d41..1d8d002e2 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -1,4 +1,4 @@ -package stream +package stream_test import ( "bytes" @@ -7,27 +7,37 @@ import ( "io" "testing" + "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" ) -func TestFileStream_RangeRead(t *testing.T) { +func TestRangeRead(t *testing.T) { type args struct { httpRange http_range.Range } buf := []byte("github.com/OpenListTeam/OpenList") - f := &FileStream{ + f := &stream.FileStream{ Obj: &model.Object{ Size: int64(len(buf)), }, Reader: io.NopCloser(bytes.NewReader(buf)), } + prevAutoMemoryLimit := conf.AutoMemoryLimit + prevMaxBlockLimit := conf.MaxBlockLimit + t.Cleanup(func() { + conf.AutoMemoryLimit = prevAutoMemoryLimit + conf.MaxBlockLimit = prevMaxBlockLimit + }) + conf.AutoMemoryLimit = 0 + conf.MaxBlockLimit = 15 tests := []struct { name string - f *FileStream + f *stream.FileStream args args - want func(f *FileStream, got io.Reader, err error) error + want func(f *stream.FileStream, got io.Reader, err error) error }{ { name: "range 11-12", @@ -35,7 +45,7 @@ func TestFileStream_RangeRead(t *testing.T) { args: args{ httpRange: http_range.Range{Start: 11, Length: 12}, }, - want: func(f *FileStream, got io.Reader, err error) error { + want: func(f *stream.FileStream, got io.Reader, err error) error { if f.GetFile() != nil { return errors.New("cached") } @@ -52,7 +62,7 @@ func TestFileStream_RangeRead(t *testing.T) { args: args{ httpRange: http_range.Range{Start: 11, Length: 21}, }, - want: func(f *FileStream, got io.Reader, err error) error { + want: func(f *stream.FileStream, got io.Reader, err error) error { if f.GetFile() == nil { return errors.New("not cached") } @@ -84,14 +94,22 @@ func TestFileStream_RangeRead(t *testing.T) { } } -func TestFileStream_With_PreHash(t *testing.T) { +func TestPreHash(t *testing.T) { buf := []byte("github.com/OpenListTeam/OpenList") - f := &FileStream{ + f := &stream.FileStream{ Obj: &model.Object{ Size: int64(len(buf)), }, Reader: io.NopCloser(bytes.NewReader(buf)), } + prevAutoMemoryLimit := conf.AutoMemoryLimit + prevMaxBlockLimit := conf.MaxBlockLimit + t.Cleanup(func() { + conf.AutoMemoryLimit = prevAutoMemoryLimit + conf.MaxBlockLimit = prevMaxBlockLimit + }) + conf.AutoMemoryLimit = 0 + conf.MaxBlockLimit = 15 const hashSize int64 = 20 reader, _ := f.RangeRead(http_range.Range{Start: 0, Length: hashSize}) @@ -99,7 +117,7 @@ func TestFileStream_With_PreHash(t *testing.T) { if preHash == "" { t.Error("preHash is empty") } - tmpF, fullHash, _ := CacheFullAndHash(f, nil, utils.SHA1) + tmpF, fullHash, _ := stream.CacheFullAndHash(f, nil, utils.SHA1) fmt.Println(fullHash) fileFullHash, _ := utils.HashFile(utils.SHA1, tmpF) fmt.Println(fileFullHash) @@ -107,3 +125,61 @@ func TestFileStream_With_PreHash(t *testing.T) { t.Errorf("fullHash and fileFullHash should match: fullHash=%s fileFullHash=%s", fullHash, fileFullHash) } } + +func TestStreamSectionReader(t *testing.T) { + buf := make([]byte, 8<<10) + for i := range len(buf) { + buf[i] = byte(i % 256) + } + f := &stream.FileStream{ + Obj: &model.Object{ + Size: int64(len(buf)), + }, + Reader: io.NopCloser(bytes.NewReader(buf)), + } + prevAutoMemoryLimit := conf.AutoMemoryLimit + prevMaxBlockLimit := conf.MaxBlockLimit + prevConf := conf.Conf + t.Cleanup(func() { + conf.AutoMemoryLimit = prevAutoMemoryLimit + conf.MaxBlockLimit = prevMaxBlockLimit + conf.Conf = prevConf + }) + conf.AutoMemoryLimit = 0 + conf.MaxBlockLimit = 2 << 10 + partSize := 3 << 10 + conf.Conf = &conf.Config{} + ss, err := stream.NewStreamSectionReader(f, partSize, nil) + if err != nil { + t.Errorf("NewStreamSectionReader() error = %v", err) + } + for i := 0; i < len(buf); i += partSize { + length := partSize + if i+length > len(buf) { + length = len(buf) - i + } + rs, err := ss.GetSectionReader(int64(i), int64(length)) + if err != nil { + t.Errorf("StreamSectionReader.GetSectionReader() error = %v", err) + } + b1, err := io.ReadAll(rs) + if err != nil { + t.Errorf("StreamSectionReader.Read() error = %v", err) + } + rs.Seek(1, io.SeekStart) + b2, _ := io.ReadAll(rs) + if !bytes.Equal(b1[1:], b2) { + t.Errorf("StreamSectionReader.Read() = %s, want %s", b1[1:], b2) + } + if !bytes.Equal(buf[i:i+length], b1) { + t.Errorf("StreamSectionReader.Read() = %s, want %s", b1, buf[i:i+length]) + } + if i == 0 { + prevMinFreeMemory := conf.MinFreeMemory + conf.MinFreeMemory = 0 // 强制使用文件缓存 + t.Cleanup(func() { + conf.MinFreeMemory = prevMinFreeMemory + }) + } + } +} diff --git a/internal/stream/util.go b/internal/stream/util.go index 6aa3dda5d..e097b3cc3 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -1,33 +1,326 @@ package stream import ( - "bytes" "context" "encoding/hex" "errors" "fmt" "io" "net/http" - "os" + "strings" + "sync" + "time" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" + hcache "github.com/OpenListTeam/OpenList/v4/internal/hybrid_cache" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/net" + "github.com/OpenListTeam/OpenList/v4/pkg/buffer" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" - "github.com/OpenListTeam/OpenList/v4/pkg/pool" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/rclone/rclone/lib/mmap" log "github.com/sirupsen/logrus" ) +const ( + // 链接刷新相关常量 + MAX_LINK_REFRESH_COUNT = 50 // 下载链接最大刷新次数(支持长时间传输) + + // RangeRead 重试相关常量 + MAX_RANGE_READ_RETRY_COUNT = 5 // RangeRead 最大重试次数(从3增加到5) +) + type RangeReaderFunc func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { return f(ctx, httpRange) } +// IsLinkExpiredError checks if the error indicates an expired download link +func IsLinkExpiredError(err error) bool { + if err == nil { + return false + } + + // Don't treat context cancellation as link expiration + // This happens when user pauses/seeks video or cancels download + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + + errStr := strings.ToLower(err.Error()) + + // Common expired link error keywords + expiredKeywords := []string{ + "expired", "invalid signature", "token expired", + "access denied", "forbidden", "unauthorized", + "link has expired", "url expired", "request has expired", + "signature expired", "accessdenied", "invalidtoken", + } + for _, keyword := range expiredKeywords { + if strings.Contains(errStr, keyword) { + return true + } + } + + // Check for HTTP status codes that typically indicate expired links + if statusErr, ok := errs.UnwrapOrSelf(err).(net.HttpStatusCodeError); ok { + code := int(statusErr) + // All 4xx client errors may indicate expired/invalid links + // 400 Bad Request, 401 Unauthorized, 403 Forbidden, 404 Not Found, 410 Gone, etc. + if code >= 400 && code < 500 { + return true + } + } + + return false +} + +// RefreshableRangeReader wraps a RangeReader with link refresh capability +type RefreshableRangeReader struct { + link *model.Link + size int64 + innerReader model.RangeReaderIF + mu sync.Mutex + refreshCount int // track refresh count to avoid infinite loops +} + +// NewRefreshableRangeReader creates a new RefreshableRangeReader +func NewRefreshableRangeReader(link *model.Link, size int64) *RefreshableRangeReader { + return &RefreshableRangeReader{ + link: link, + size: size, + } +} + +func (r *RefreshableRangeReader) getInnerReader() (model.RangeReaderIF, error) { + if r.innerReader != nil { + return r.innerReader, nil + } + + // Create inner reader without Refresher to avoid recursion + linkCopy := *r.link + linkCopy.Refresher = nil + + reader, err := GetRangeReaderFromLink(r.size, &linkCopy) + if err != nil { + return nil, err + } + r.innerReader = reader + return reader, nil +} + +// RangeRead obtains a reader reference under lock, then issues the range +// request outside the lock. The local copy of `reader` remains valid even +// if a concurrent refresh nils r.innerReader and replaces r.link, because +// doRefreshLocked only clears the pointer — it does not close or invalidate +// the old reader object. Each reader is an independent RangeReaderFunc +// (closure over a URL), so stale readers simply use the old URL; if it has +// expired, selfHealingReadCloser handles the retry transparently. +func (r *RefreshableRangeReader) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + r.mu.Lock() + reader, err := r.getInnerReader() + r.mu.Unlock() + if err != nil { + return nil, err + } + + rc, err := reader.RangeRead(ctx, httpRange) + if err != nil { + // Check if we should try to refresh on initial connection error + if IsLinkExpiredError(err) && r.link.Refresher != nil { + rc, err = r.refreshAndRetry(ctx, httpRange) + } + if err != nil { + return nil, err + } + } + + // Wrap the ReadCloser with self-healing capability to detect 0-byte reads + // This handles cases where cloud providers return 200 OK but empty body for expired links + return &selfHealingReadCloser{ + ReadCloser: rc, + refresher: r, + ctx: ctx, + httpRange: httpRange, + firstRead: false, + closed: false, + }, nil +} + +func (r *RefreshableRangeReader) refreshAndRetry(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if err := r.doRefreshLocked(ctx); err != nil { + return nil, err + } + + reader, err := r.getInnerReader() + if err != nil { + return nil, err + } + return reader.RangeRead(ctx, httpRange) +} + +// doRefreshLocked 执行实际的刷新逻辑(需要持有锁) +func (r *RefreshableRangeReader) doRefreshLocked(ctx context.Context) error { + if r.refreshCount >= MAX_LINK_REFRESH_COUNT { + return fmt.Errorf("max refresh attempts (%d) reached", MAX_LINK_REFRESH_COUNT) + } + + log.Infof("Link expired, attempting to refresh...") + // Use independent context for refresh to prevent cancellation from affecting link refresh + refreshCtx := context.WithoutCancel(ctx) + newLink, _, refreshErr := r.link.Refresher(refreshCtx) + if refreshErr != nil { + return fmt.Errorf("failed to refresh link: %w", refreshErr) + } + + newLink.Refresher = r.link.Refresher + r.link = newLink + r.innerReader = nil + r.refreshCount++ + + log.Infof("Link refreshed successfully") + return nil +} + +// selfHealingReadCloser wraps an io.ReadCloser and automatically refreshes the link +// if the upstream reader dies before the requested range is fully delivered. +type selfHealingReadCloser struct { + io.ReadCloser + refresher *RefreshableRangeReader + ctx context.Context + httpRange http_range.Range + firstRead bool + bytesRead int64 + closed bool + mu sync.Mutex +} + +func (s *selfHealingReadCloser) Read(p []byte) (n int, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return 0, errors.New("read from closed reader") + } + + n, err = s.ReadCloser.Read(p) + s.bytesRead += int64(n) + wasFirstRead := !s.firstRead + s.firstRead = true + + // Detect 0-byte read on first attempt (indicates link may be expired but returned 200 OK) + if s.shouldReconnectAfterRead(wasFirstRead, n, err) { + if reconnectErr := s.reconnectFromCurrentOffsetLocked(); reconnectErr != nil { + log.Errorf("Failed to refresh link after interrupted read: %v", reconnectErr) + return n, err + } + + if n > 0 { + return n, nil + } + + n, err = s.ReadCloser.Read(p) + s.bytesRead += int64(n) + return n, err + } + + return n, err +} + +func (s *selfHealingReadCloser) shouldReconnectAfterRead(wasFirstRead bool, n int, err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + if s.remainingBytes() <= 0 { + return false + } + + if wasFirstRead && n == 0 && (err == io.EOF || err == io.ErrUnexpectedEOF) { + log.Warnf("Detected 0-byte read on first attempt, attempting to refresh link...") + return true + } + + if errors.Is(err, io.ErrUnexpectedEOF) { + log.Warnf("Detected interrupted read after %d bytes, attempting to refresh link...", s.bytesRead) + return true + } + + errStr := strings.ToLower(err.Error()) + if strings.Contains(errStr, "connection reset by peer") { + log.Warnf("Detected upstream connection reset after %d bytes, attempting to refresh link...", s.bytesRead) + return true + } + + return false +} + +func (s *selfHealingReadCloser) reconnectFromCurrentOffsetLocked() error { + nextRange := s.httpRange + nextRange.Start += s.bytesRead + if nextRange.Length >= 0 { + nextRange.Length -= s.bytesRead + } + + s.refresher.mu.Lock() + refreshErr := s.refresher.doRefreshLocked(s.ctx) + if refreshErr != nil { + s.refresher.mu.Unlock() + return refreshErr + } + + reader, getErr := s.refresher.getInnerReader() + s.refresher.mu.Unlock() + if getErr != nil { + return getErr + } + + newRc, rangeErr := reader.RangeRead(s.ctx, nextRange) + if rangeErr != nil { + return rangeErr + } + + _ = s.ReadCloser.Close() + s.ReadCloser = newRc + log.Infof("Successfully refreshed link and reconnected from offset %d", nextRange.Start) + return nil +} + +func (s *selfHealingReadCloser) remainingBytes() int64 { + length := s.httpRange.Length + if length < 0 || s.httpRange.Start+length > s.refresher.size { + length = s.refresher.size - s.httpRange.Start + } + remaining := length - s.bytesRead + if remaining < 0 { + return 0 + } + return remaining +} + +func (s *selfHealingReadCloser) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return nil + } + s.closed = true + return s.ReadCloser.Close() +} + func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) { + // If link has a Refresher, wrap with RefreshableRangeReader for automatic refresh on expiry + if link.Refresher != nil { + return NewRefreshableRangeReader(link, size), nil + } + if link.RangeReader != nil { if link.Concurrency < 1 && link.PartSize < 1 { return link.RangeReader, nil @@ -98,6 +391,19 @@ func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, } return nil, fmt.Errorf("http request failure, err:%w", err) } + // "Soft 200" expired-link guard: when we asked for ≥1 byte but the + // server promised 0 (e.g. 115 CDN for a stale `?t=` URL — 200 OK + + // Content-Length: 0), the body would be empty and any downstream + // client would interpret it as a corrupt/empty file. Surface this + // as an explicit "expired" error so RefreshableRangeReader can + // trigger a refresh, and callers without a Refresher fail loudly + // instead of streaming silence. ContentLength == -1 means the + // server used chunked transfer encoding and is excluded. + if httpRange.Length > 0 && response.ContentLength == 0 { + response.Body.Close() + return nil, fmt.Errorf("link expired: server returned status %d with Content-Length: 0 (expected %d bytes from %s)", + response.StatusCode, httpRange.Length, link.URL) + } if ServerDownloadLimit != nil { response.Body = &RateLimitReader{ Ctx: ctx, @@ -174,81 +480,204 @@ func CacheFullAndHash(stream model.FileStreamer, up *model.UpdateProgress, hashT return tmpF, hex.EncodeToString(h.Sum(nil)), nil } -type StreamSectionReaderIF interface { +// ReadFullWithRangeRead 使用 RangeRead 从文件流中读取数据到 buf +// file: 文件流 +// buf: 目标缓冲区 +// off: 读取的起始偏移量 +// 返回值: 实际读取的字节数和错误 +// 支持自动重试(最多5次),快速重试策略(1秒、2秒、3秒、4秒、5秒) +// 注意:链接刷新现在由 RefreshableRangeReader 内部的 selfHealingReadCloser 自动处理 +func ReadFullWithRangeRead(file model.FileStreamer, buf []byte, off int64) (int, error) { + length := int64(len(buf)) + var lastErr error + + // 重试最多 MAX_RANGE_READ_RETRY_COUNT 次 + for retry := 0; retry < MAX_RANGE_READ_RETRY_COUNT; retry++ { + reader, err := file.RangeRead(http_range.Range{Start: off, Length: length}) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return 0, err + } + lastErr = fmt.Errorf("RangeRead failed at offset %d: %w", off, err) + log.Debugf("RangeRead retry %d failed: %v", retry+1, lastErr) + // 快速重试:1秒、2秒、3秒、4秒、5秒(连接失败快速重试) + time.Sleep(time.Duration(retry+1) * time.Second) + continue + } + + n, err := io.ReadFull(reader, buf) + if closer, ok := reader.(io.Closer); ok { + closer.Close() + } + + if err == nil { + return n, nil + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return n, err + } + + lastErr = fmt.Errorf("failed to read all data via RangeRead at offset %d: (expect=%d, actual=%d) %w", off, length, n, err) + log.Debugf("RangeRead retry %d read failed: %v", retry+1, lastErr) + + // 快速重试:1秒、2秒、3秒、4秒、5秒(读取失败快速重试) + // 注意:0字节读取导致的链接过期现在由 selfHealingReadCloser 自动处理 + time.Sleep(time.Duration(retry+1) * time.Second) + } + + return 0, lastErr +} + +// StreamHashFile 流式计算文件哈希值,避免将整个文件加载到内存 +// file: 文件流 +// hashType: 哈希算法类型 +// progressWeight: 进度权重(0-100),用于计算整体进度 +// up: 进度回调函数 +func StreamHashFile(file model.FileStreamer, hashType *utils.HashType, progressWeight float64, up *model.UpdateProgress) (string, error) { + // 如果已经有完整缓存文件,直接使用 + if cache := file.GetFile(); cache != nil { + hashFunc := hashType.NewFunc() + cache.Seek(0, io.SeekStart) + _, err := io.Copy(hashFunc, cache) + if err != nil { + return "", err + } + if up != nil && progressWeight > 0 { + (*up)(progressWeight) + } + return hex.EncodeToString(hashFunc.Sum(nil)), nil + } + + hashFunc := hashType.NewFunc() + size := file.GetSize() + chunkSize := int64(10 * 1024 * 1024) // 10MB per chunk + + if _, ok := file.(*SeekableStream); ok { + return streamHashSeekableWithPrefetch(file, hashFunc, size, chunkSize, progressWeight, up) + } + + buf := make([]byte, chunkSize) + var offset int64 = 0 + for offset < size { + readSize := chunkSize + if size-offset < chunkSize { + readSize = size - offset + } + n, err := io.ReadFull(file, buf[:readSize]) + if err != nil { + log.Warnf("StreamHashFile: sequential read failed at offset %d, retrying with RangeRead: %v", offset, err) + n, err = ReadFullWithRangeRead(file, buf[:readSize], offset) + } + if err != nil { + return "", fmt.Errorf("calculate hash failed at offset %d: %w", offset, err) + } + hashFunc.Write(buf[:n]) + offset += int64(n) + if up != nil && progressWeight > 0 { + (*up)(progressWeight * float64(offset) / float64(size)) + } + } + return hex.EncodeToString(hashFunc.Sum(nil)), nil +} + +type hashPrefetchResult struct { + buf []byte + n int + err error +} + +func streamHashSeekableWithPrefetch(file model.FileStreamer, hashFunc io.Writer, size, chunkSize int64, progressWeight float64, up *model.UpdateProgress) (string, error) { + readChunkSize := func(off int64) int64 { + if size-off < chunkSize { + return size - off + } + return chunkSize + } + + var offset int64 + + // Read first chunk synchronously + firstSize := readChunkSize(0) + curBuf := make([]byte, chunkSize) + curN, curErr := ReadFullWithRangeRead(file, curBuf[:firstSize], 0) + if curErr != nil { + return "", fmt.Errorf("calculate hash failed at offset 0: %w", curErr) + } + + for { + nextOff := offset + int64(curN) + + // Launch prefetch for next chunk while we hash current + var prefetchCh chan hashPrefetchResult + if nextOff < size { + prefetchCh = make(chan hashPrefetchResult, 1) + nextSize := readChunkSize(nextOff) + nextBuf := make([]byte, nextSize) + go func(buf []byte, off, sz int64) { + n, err := ReadFullWithRangeRead(file, buf[:sz], off) + prefetchCh <- hashPrefetchResult{buf: buf, n: n, err: err} + }(nextBuf, nextOff, nextSize) + } + + // Hash current chunk + hashFunc.Write(curBuf[:curN]) + offset += int64(curN) + + if up != nil && progressWeight > 0 { + (*up)(progressWeight * float64(offset) / float64(size)) + } + + if prefetchCh == nil { + break + } + + // Wait for prefetch + result := <-prefetchCh + if result.err != nil { + return "", fmt.Errorf("calculate hash failed at offset %d: %w", nextOff, result.err) + } + curBuf = result.buf + curN = result.n + } + + if h, ok := hashFunc.(interface{ Sum([]byte) []byte }); ok { + return hex.EncodeToString(h.Sum(nil)), nil + } + return "", fmt.Errorf("hashFunc does not implement Sum") +} + +type StreamSectionReader interface { // 线程不安全 GetSectionReader(off, length int64) (io.ReadSeeker, error) + // 线程安全 FreeSectionReader(sr io.ReadSeeker) // 线程不安全 DiscardSection(off int64, length int64) error } -func NewStreamSectionReader(file model.FileStreamer, maxBufferSize int, up *model.UpdateProgress) (StreamSectionReaderIF, error) { +func NewStreamSectionReader(file model.FileStreamer, sectionSize int, up *model.UpdateProgress) (StreamSectionReader, error) { if file.GetFile() != nil { return &cachedSectionReader{file.GetFile()}, nil } - maxBufferSize = min(maxBufferSize, int(file.GetSize())) - if maxBufferSize > conf.MaxBufferLimit { - f, err := os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return nil, err - } - - if f.Truncate(file.GetSize()) != nil { - // fallback to full cache - _, _ = f.Close(), os.Remove(f.Name()) - cache, err := file.CacheFullAndWriter(up, nil) - if err != nil { - return nil, err - } - return &cachedSectionReader{cache}, nil - } - - ss := &fileSectionReader{file: file, temp: f} - ss.bufPool = &pool.Pool[*offsetWriterWithBase]{ - New: func() *offsetWriterWithBase { - base := ss.tempOffset - ss.tempOffset += int64(maxBufferSize) - return &offsetWriterWithBase{io.NewOffsetWriter(ss.temp, base), base} - }, - } - file.Add(utils.CloseFunc(func() error { - ss.bufPool.Reset() - return errors.Join(ss.temp.Close(), os.Remove(ss.temp.Name())) - })) - return ss, nil - } - - ss := &directSectionReader{file: file} - if conf.MmapThreshold > 0 && maxBufferSize >= conf.MmapThreshold { - ss.bufPool = &pool.Pool[[]byte]{ - New: func() []byte { - buf, err := mmap.Alloc(maxBufferSize) - if err == nil { - file.Add(utils.CloseFunc(func() error { - return mmap.Free(buf) - })) - } else { - buf = make([]byte, maxBufferSize) - } - return buf - }, - } - } else { - ss.bufPool = &pool.Pool[[]byte]{ - New: func() []byte { - return make([]byte, maxBufferSize) - }, - } + blockSize := min(uint64(sectionSize), uint64(file.GetSize()), conf.MaxBlockLimit) + hc, err := hcache.NewHybridCache(blockSize, uint64(file.GetSize())) + if err != nil { + return nil, err } - - file.Add(utils.CloseFunc(func() error { - ss.bufPool.Reset() - return nil - })) + file.Add(hc) + ss := &hybridSectionReader{file: file, hc: hc, fileSize: file.GetSize()} + // Wait for any pending prefetch when the file is closed so we don't + // race against ss.hc being freed. + file.Add(closerFunc(ss.waitPrefetch)) return ss, nil } +type closerFunc func() error + +func (f closerFunc) Close() error { return f() } + type cachedSectionReader struct { cache io.ReaderAt } @@ -261,21 +690,52 @@ func (s *cachedSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker } func (*cachedSectionReader) FreeSectionReader(sr io.ReadSeeker) {} -type fileSectionReader struct { +type hybridSectionReader struct { file model.FileStreamer fileOffset int64 - temp *os.File - tempOffset int64 - bufPool *pool.Pool[*offsetWriterWithBase] + fileSize int64 + hc *hcache.HybridCache + mu sync.Mutex + cache []buffer.Block + + // Pass 2 prefetch: while the caller uploads block N, we read block + // N+1 from the source in the background so download/upload overlap. + // Access is serialized through GetSectionReader/DiscardSection which + // are documented as 线程不安全; only the background prefetch goroutine + // touches `prefetch` concurrently with those methods, and waitPrefetch + // drains it before any of them touches ss.file again. + prefetch *prefetchTask } -type offsetWriterWithBase struct { - *io.OffsetWriter - base int64 +type prefetchTask struct { + off int64 // file offset the prefetch started at + length int64 // bytes requested + actual int64 // bytes actually read into block (may be < length on EOF/error) + block buffer.Block // nil if allocation/read failed before any bytes + err error // non-nil on prefetch error + done chan struct{} } // 线程不安全 -func (ss *fileSectionReader) DiscardSection(off int64, length int64) error { +func (ss *hybridSectionReader) DiscardSection(off int64, length int64) error { + // Drain any pending prefetch first so ss.file is quiescent. If the + // prefetched range exactly matches the discard request, we're done — + // the bytes have already been read from the source. + if p := ss.prefetch; p != nil { + <-p.done + ss.prefetch = nil + if p.err == nil && p.off == off && p.actual == length { + if p.block != nil { + ss.put(p.block) + } + ss.fileOffset = off + length + return nil + } + if p.block != nil { + ss.put(p.block) + } + ss.fileOffset = p.off + p.actual + } if off != ss.fileOffset { return fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) } @@ -287,76 +747,171 @@ func (ss *fileSectionReader) DiscardSection(off int64, length int64) error { return nil } -type fileBufferSectionReader struct { +type blockRefReadSeeker struct { io.ReadSeeker - fileBuf *offsetWriterWithBase + b buffer.Block } // 线程不安全 -func (ss *fileSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker, error) { +func (ss *hybridSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker, error) { + // Try prefetched block first. + if b, actual, err, ok := ss.takePrefetched(off, length); ok { + if err != nil { + if b != nil { + ss.put(b) + } + return nil, fmt.Errorf("prefetch failed at offset %d: %w", off, err) + } + if actual < length { + if b != nil { + ss.put(b) + } + return nil, fmt.Errorf("prefetch short read at offset %d: (expect=%d, actual=%d)", off, length, actual) + } + ss.fileOffset = off + length + ss.schedulePrefetch(off+length, length) + return makeBlockReadSeeker(b, length) + } + if off != ss.fileOffset { return nil, fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) } - fileBuf := ss.bufPool.Get() - _, _ = fileBuf.Seek(0, io.SeekStart) - n, err := utils.CopyWithBufferN(fileBuf, ss.file, length) - ss.fileOffset += n - if err != nil { - return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) + b, actual, err := ss.readBlock(length) + if err != nil || actual != length { + if b != nil { + ss.put(b) + } + return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, actual, err) } - return &fileBufferSectionReader{io.NewSectionReader(ss.temp, fileBuf.base, length), fileBuf}, nil + ss.fileOffset += actual + ss.schedulePrefetch(off+length, length) + return makeBlockReadSeeker(b, length) } -func (ss *fileSectionReader) FreeSectionReader(rs io.ReadSeeker) { - if sr, ok := rs.(*fileBufferSectionReader); ok { - ss.bufPool.Put(sr.fileBuf) - sr.fileBuf = nil - sr.ReadSeeker = nil +// readBlock reads `length` bytes from ss.file into a freshly populated +// buffer.Block. The returned block may be nil if no bytes could be read. +// Caller is responsible for returning the block to the pool on error. +func (ss *hybridSectionReader) readBlock(length int64) (buffer.Block, int64, error) { + b := ss.get() + if b == nil { + offset := int64(ss.hc.Size()) + written, err := ss.hc.CopyFromN(ss.file, length) + if written == 0 { + return nil, 0, err + } + b = buffer.NewBlockAdapter( + io.NewOffsetWriter(ss.hc, offset), + io.NewSectionReader(ss.hc, offset, written), + ) + return b, written, err + } + ws := buffer.WriteAtSeekerOf(b) + if _, err := ws.Seek(0, io.SeekStart); err != nil { + ss.put(b) + return nil, 0, fmt.Errorf("failed to reset cached block writer: %w", err) } + written, err := utils.CopyWithBufferN(ws, ss.file, length) + return b, written, err } -type directSectionReader struct { - file model.FileStreamer - fileOffset int64 - bufPool *pool.Pool[[]byte] +// schedulePrefetch starts a background read of `length` bytes at file +// offset `off`. It is a no-op if there is nothing more to read or a +// prefetch is already in flight. +func (ss *hybridSectionReader) schedulePrefetch(off, length int64) { + if length <= 0 || off >= ss.fileSize { + return + } + if ss.prefetch != nil { + return + } + // Clamp to remaining file size so the last partial chunk doesn't + // produce a synthetic short-read error. + if remaining := ss.fileSize - off; remaining < length { + length = remaining + } + task := &prefetchTask{off: off, length: length, done: make(chan struct{})} + ss.prefetch = task + go func() { + defer close(task.done) + b, actual, err := ss.readBlock(length) + task.block = b + task.actual = actual + task.err = err + }() } -// 线程不安全 -func (ss *directSectionReader) DiscardSection(off int64, length int64) error { - if off != ss.fileOffset { - return fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) +// takePrefetched returns the prefetched block if it matches the caller's +// requested offset. The ok return distinguishes "no prefetch present" +// (false) from "prefetch was consumed" (true). +func (ss *hybridSectionReader) takePrefetched(off, length int64) (buffer.Block, int64, error, bool) { + p := ss.prefetch + if p == nil { + return nil, 0, nil, false } - n, err := utils.CopyWithBufferN(io.Discard, ss.file, length) - ss.fileOffset += n - if err != nil { - return fmt.Errorf("failed to skip data: (expect =%d, actual =%d) %w", length, n, err) + <-p.done + ss.prefetch = nil + if p.off != off { + if p.block != nil { + ss.put(p.block) + } + // Source has already advanced by p.actual bytes past p.off. + // Reflect that so subsequent calls see a consistent fileOffset. + ss.fileOffset = p.off + p.actual + return nil, 0, nil, false } - return nil + // Caller may ask for fewer bytes than we prefetched (e.g. final + // partial chunk after we over-prefetched). Allow that. + if length > p.actual && p.err == nil { + // Did not get as many bytes as caller wants and source didn't + // signal an error — surface a short-read for the caller to + // handle, but keep block so it can be returned. + return p.block, p.actual, fmt.Errorf("short read"), true + } + return p.block, p.actual, p.err, true } -type bufferSectionReader struct { - io.ReadSeeker - buf []byte +func (ss *hybridSectionReader) waitPrefetch() error { + if p := ss.prefetch; p != nil { + <-p.done + ss.prefetch = nil + if p.block != nil { + ss.put(p.block) + } + } + return nil } -// 线程不安全 -func (ss *directSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker, error) { - if off != ss.fileOffset { - return nil, fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) +func makeBlockReadSeeker(b buffer.Block, length int64) (io.ReadSeeker, error) { + if length == b.Size() { + rs := buffer.ReadAtSeekerOf(b) + if _, err := rs.Seek(0, io.SeekStart); err != nil { + return nil, fmt.Errorf("failed to reset cached block reader: %w", err) + } + return &blockRefReadSeeker{rs, b}, nil } - tempBuf := ss.bufPool.Get() - buf := tempBuf[:length] - n, err := io.ReadFull(ss.file, buf) - ss.fileOffset += int64(n) - if int64(n) != length { - return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) + return &blockRefReadSeeker{io.NewSectionReader(b, 0, length), b}, nil +} + +func (ss *hybridSectionReader) get() buffer.Block { + ss.mu.Lock() + defer ss.mu.Unlock() + if len(ss.cache) > 0 { + b := ss.cache[len(ss.cache)-1] + ss.cache = ss.cache[:len(ss.cache)-1] + return b } - return &bufferSectionReader{bytes.NewReader(buf), buf}, nil + return nil +} +func (ss *hybridSectionReader) put(b buffer.Block) { + ss.mu.Lock() + defer ss.mu.Unlock() + ss.cache = append(ss.cache, b) } -func (ss *directSectionReader) FreeSectionReader(rs io.ReadSeeker) { - if sr, ok := rs.(*bufferSectionReader); ok { - ss.bufPool.Put(sr.buf[0:cap(sr.buf)]) - sr.buf = nil + +func (ss *hybridSectionReader) FreeSectionReader(rs io.ReadSeeker) { + if sr, ok := rs.(*blockRefReadSeeker); ok { + ss.put(sr.b) + sr.b = nil sr.ReadSeeker = nil } } diff --git a/internal/stream/util_test.go b/internal/stream/util_test.go new file mode 100644 index 000000000..4bdd1f506 --- /dev/null +++ b/internal/stream/util_test.go @@ -0,0 +1,496 @@ +package stream + +import ( + "bytes" + "context" + "encoding/hex" + "errors" + "io" + "net/http" + "net/http/httptest" + "strconv" + "sync" + "sync/atomic" + "testing" + + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" +) + +// ensureConfForHTTP makes sure conf.Conf is non-nil before any test reaches +// net.HttpClient — that helper deferences conf.Conf.TlsInsecureSkipVerify +// during a sync.Once init and panics on a nil pointer otherwise. +func ensureConfForHTTP() { + if conf.Conf == nil { + conf.Conf = &conf.Config{} + } +} + +func TestRefreshableRangeReader_ReconnectsAfterMidStreamReset(t *testing.T) { + data := []byte("0123456789abcdef") + var refreshes int + var mu sync.Mutex + var resumedRanges []http_range.Range + + initial := RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + return newFlakyReadCloser(sliceForRange(data, httpRange), 5, errors.New("read tcp 127.0.0.1:443: read: connection reset by peer")), nil + }) + resumed := RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + mu.Lock() + resumedRanges = append(resumedRanges, httpRange) + mu.Unlock() + return io.NopCloser(bytes.NewReader(sliceForRange(data, httpRange))), nil + }) + + link := &model.Link{RangeReader: initial} + link.Refresher = func(ctx context.Context) (*model.Link, model.Obj, error) { + refreshes++ + return &model.Link{RangeReader: resumed}, nil, nil + } + + reader, err := NewRefreshableRangeReader(link, int64(len(data))).RangeRead(context.Background(), http_range.Range{Start: 0, Length: int64(len(data))}) + if err != nil { + t.Fatalf("RangeRead() error = %v", err) + } + defer reader.Close() + + got, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if !bytes.Equal(got, data) { + t.Fatalf("ReadAll() = %q, want %q", got, data) + } + if refreshes != 1 { + t.Fatalf("refreshes = %d, want 1", refreshes) + } + + mu.Lock() + defer mu.Unlock() + if len(resumedRanges) != 1 { + t.Fatalf("len(resumedRanges) = %d, want 1", len(resumedRanges)) + } + if resumedRanges[0].Start != 5 { + t.Fatalf("resumed range start = %d, want 5", resumedRanges[0].Start) + } + if resumedRanges[0].Length != int64(len(data)-5) { + t.Fatalf("resumed range length = %d, want %d", resumedRanges[0].Length, len(data)-5) + } +} + +func TestRefreshableRangeReader_ReconnectsAfterMidStreamReset_UnboundedRange(t *testing.T) { + data := []byte("0123456789abcdef") + var refreshes int + var mu sync.Mutex + var resumedRanges []http_range.Range + + initial := RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + return newFlakyReadCloser(sliceForRange(data, httpRange), 5, errors.New("read tcp 127.0.0.1:443: read: connection reset by peer")), nil + }) + resumed := RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + mu.Lock() + resumedRanges = append(resumedRanges, httpRange) + mu.Unlock() + return io.NopCloser(bytes.NewReader(sliceForRange(data, httpRange))), nil + }) + + link := &model.Link{RangeReader: initial} + link.Refresher = func(ctx context.Context) (*model.Link, model.Obj, error) { + refreshes++ + return &model.Link{RangeReader: resumed}, nil, nil + } + + reader, err := NewRefreshableRangeReader(link, int64(len(data))).RangeRead(context.Background(), http_range.Range{Start: 0, Length: -1}) + if err != nil { + t.Fatalf("RangeRead() error = %v", err) + } + defer reader.Close() + + got, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if !bytes.Equal(got, data) { + t.Fatalf("ReadAll() = %q, want %q", got, data) + } + if refreshes != 1 { + t.Fatalf("refreshes = %d, want 1", refreshes) + } + + mu.Lock() + defer mu.Unlock() + if len(resumedRanges) != 1 { + t.Fatalf("len(resumedRanges) = %d, want 1", len(resumedRanges)) + } + if resumedRanges[0].Start != 5 { + t.Fatalf("resumed range start = %d, want 5", resumedRanges[0].Start) + } + if resumedRanges[0].Length != -1 { + t.Fatalf("resumed range length = %d, want -1", resumedRanges[0].Length) + } +} + +// TestSelfHealingReadCloser_NormalEOFDoesNotTriggerReconnect verifies that a +// legitimate io.EOF (all data delivered) does NOT trigger a link refresh. +func TestSelfHealingReadCloser_NormalEOFDoesNotTriggerReconnect(t *testing.T) { + data := []byte("hello world") + refreshes := 0 + + inner := RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(sliceForRange(data, httpRange))), nil + }) + + link := &model.Link{RangeReader: inner} + link.Refresher = func(ctx context.Context) (*model.Link, model.Obj, error) { + refreshes++ + return &model.Link{RangeReader: inner}, nil, nil + } + + rrr := NewRefreshableRangeReader(link, int64(len(data))) + rc, err := rrr.RangeRead(context.Background(), http_range.Range{Start: 0, Length: int64(len(data))}) + if err != nil { + t.Fatalf("RangeRead error: %v", err) + } + defer rc.Close() + + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if !bytes.Equal(got, data) { + t.Fatalf("got %q, want %q", got, data) + } + if refreshes != 0 { + t.Fatalf("refreshes = %d, want 0 (normal EOF should not trigger refresh)", refreshes) + } +} + +// TestSelfHealingReadCloser_UnexpectedEOFTriggersReconnect verifies that +// io.ErrUnexpectedEOF (stream interrupted) DOES trigger reconnect. +func TestSelfHealingReadCloser_UnexpectedEOFTriggersReconnect(t *testing.T) { + data := []byte("0123456789") + refreshes := 0 + + initial := RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + return newFlakyReadCloser(sliceForRange(data, httpRange), 4, io.ErrUnexpectedEOF), nil + }) + resumed := RangeReaderFunc(func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(sliceForRange(data, httpRange))), nil + }) + + link := &model.Link{RangeReader: initial} + link.Refresher = func(ctx context.Context) (*model.Link, model.Obj, error) { + refreshes++ + return &model.Link{RangeReader: resumed}, nil, nil + } + + rrr := NewRefreshableRangeReader(link, int64(len(data))) + rc, err := rrr.RangeRead(context.Background(), http_range.Range{Start: 0, Length: int64(len(data))}) + if err != nil { + t.Fatalf("RangeRead error: %v", err) + } + defer rc.Close() + + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if !bytes.Equal(got, data) { + t.Fatalf("got %q, want %q", got, data) + } + if refreshes != 1 { + t.Fatalf("refreshes = %d, want 1", refreshes) + } +} + +// TestStreamHashFile_SeekablePrefetchProducesSameHash verifies that +// the prefetch optimization in StreamHashFile produces the exact same +// hash as a sequential read. +func TestStreamHashFile_SeekablePrefetchProducesSameHash(t *testing.T) { + // 50 bytes = will be split into 10MB chunks in real code, but we + // override chunkSize for testing. The key point: hash must be identical. + data := []byte("The quick brown fox jumps over the lazy dog!!!!!") // 48 bytes + + rr := RangeReaderFunc(func(ctx context.Context, r http_range.Range) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(sliceForRange(data, r))), nil + }) + + seekable := &SeekableStream{ + FileStream: &FileStream{ + Obj: &model.Object{Name: "test.bin", Size: int64(len(data))}, + Ctx: context.Background(), + }, + rangeReader: rr, + } + + hash1, err := StreamHashFile(seekable, utils.SHA1, 0, nil) + if err != nil { + t.Fatalf("StreamHashFile error: %v", err) + } + + // Compute expected hash directly + h := utils.SHA1.NewFunc() + h.Write(data) + expected := hex.EncodeToString(h.Sum(nil)) + + if hash1 != expected { + t.Fatalf("hash mismatch: got %s, want %s", hash1, expected) + } +} + +type flakyReadCloser struct { + data []byte + failAfter int + failErr error + failed bool +} + +func newFlakyReadCloser(data []byte, failAfter int, failErr error) *flakyReadCloser { + return &flakyReadCloser{ + data: data, + failAfter: failAfter, + failErr: failErr, + } +} + +func (f *flakyReadCloser) Read(p []byte) (int, error) { + if f.failed { + return 0, io.EOF + } + if f.failAfter >= len(f.data) { + f.failed = true + n := copy(p, f.data) + return n, io.EOF + } + + n := copy(p, f.data[:f.failAfter]) + f.failed = true + return n, f.failErr +} + +func (f *flakyReadCloser) Close() error { + return nil +} + +// TestRangeReaderFromLink_SoftExpiredLink_TriggersRefresh covers the +// 115 CDN "soft 200 + empty body" failure mode: when a signed URL has +// expired, 115 still answers with HTTP 200 and Content-Length: 0 instead +// of a 4xx. Without explicit detection, OP forwards an empty stream to +// the client (mpv sees "Failed to recognize file format"). This test +// pins the contract that GetRangeReaderFromLink-derived readers treat +// "non-zero range requested → zero bytes promised" as an expired link +// and let RefreshableRangeReader trigger a refresh + retry. +func TestRangeReaderFromLink_SoftExpiredLink_TriggersRefresh(t *testing.T) { + ensureConfForHTTP() + data := []byte("The quick brown fox jumps over the lazy dog") + size := int64(len(data)) + + var expiredHits int32 + expired := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&expiredHits, 1) + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) + })) + defer expired.Close() + + var freshHits int32 + fresh := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&freshHits, 1) + rangeHeader := r.Header.Get("Range") + start := int64(0) + length := size + if rangeHeader != "" { + ranges, err := http_range.ParseRange(rangeHeader, size) + if err == nil && len(ranges) == 1 { + start = ranges[0].Start + length = ranges[0].Length + w.Header().Set("Content-Range", ranges[0].ContentRange(size)) + w.Header().Set("Content-Length", strconv.FormatInt(length, 10)) + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write(data[start : start+length]) + return + } + } + w.Header().Set("Content-Length", strconv.FormatInt(length, 10)) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data[start : start+length]) + })) + defer fresh.Close() + + link := &model.Link{URL: expired.URL} + var refreshes int32 + link.Refresher = func(ctx context.Context) (*model.Link, model.Obj, error) { + atomic.AddInt32(&refreshes, 1) + return &model.Link{URL: fresh.URL}, nil, nil + } + + rrr, err := GetRangeReaderFromLink(size, link) + if err != nil { + t.Fatalf("GetRangeReaderFromLink: %v", err) + } + rc, err := rrr.RangeRead(context.Background(), http_range.Range{Start: 0, Length: size}) + if err != nil { + t.Fatalf("RangeRead: %v", err) + } + defer rc.Close() + + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(got, data) { + t.Fatalf("body mismatch: got %q, want %q", got, data) + } + if atomic.LoadInt32(&refreshes) != 1 { + t.Fatalf("refreshes = %d, want 1", refreshes) + } + if atomic.LoadInt32(&expiredHits) < 1 { + t.Fatalf("expired server never hit (= %d), expected at least once", expiredHits) + } + if atomic.LoadInt32(&freshHits) < 1 { + t.Fatalf("fresh server never hit (= %d), expected at least once after refresh", freshHits) + } +} + +// TestRangeReaderFromLink_NormalResponse_NoFalsePositive guards against the +// soft-expired check ever firing on a healthy 206 response. +func TestRangeReaderFromLink_NormalResponse_NoFalsePositive(t *testing.T) { + ensureConfForHTTP() + data := []byte("0123456789abcdef") + size := int64(len(data)) + + var refreshes int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ranges, _ := http_range.ParseRange(r.Header.Get("Range"), size) + if len(ranges) == 1 { + ra := ranges[0] + w.Header().Set("Content-Range", ra.ContentRange(size)) + w.Header().Set("Content-Length", strconv.FormatInt(ra.Length, 10)) + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write(data[ra.Start : ra.Start+ra.Length]) + return + } + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + _, _ = w.Write(data) + })) + defer server.Close() + + link := &model.Link{URL: server.URL} + link.Refresher = func(ctx context.Context) (*model.Link, model.Obj, error) { + atomic.AddInt32(&refreshes, 1) + return link, nil, nil + } + + rrr, err := GetRangeReaderFromLink(size, link) + if err != nil { + t.Fatalf("GetRangeReaderFromLink: %v", err) + } + rc, err := rrr.RangeRead(context.Background(), http_range.Range{Start: 4, Length: 6}) + if err != nil { + t.Fatalf("RangeRead: %v", err) + } + defer rc.Close() + + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(got, data[4:10]) { + t.Fatalf("body mismatch: got %q, want %q", got, data[4:10]) + } + if atomic.LoadInt32(&refreshes) != 0 { + t.Fatalf("refreshes = %d, want 0 (healthy response must not trigger refresh)", refreshes) + } +} + +// TestRangeReaderFromLink_ChunkedResponse_NoFalsePositive guards against +// false positives on responses without a Content-Length (chunked transfer). +// Go's http.Response sets ContentLength = -1 for chunked, which must not +// be treated as "0 bytes promised". +func TestRangeReaderFromLink_ChunkedResponse_NoFalsePositive(t *testing.T) { + ensureConfForHTTP() + data := []byte("chunked-payload-bytes-here-yo!") + size := int64(len(data)) + + var refreshes int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Force chunked: do not set Content-Length, write headers then body + // in two flushes. + w.Header().Set("Content-Type", "application/octet-stream") + w.WriteHeader(http.StatusOK) + flusher, _ := w.(http.Flusher) + _, _ = w.Write(data[:len(data)/2]) + if flusher != nil { + flusher.Flush() + } + _, _ = w.Write(data[len(data)/2:]) + })) + defer server.Close() + + link := &model.Link{URL: server.URL} + link.Refresher = func(ctx context.Context) (*model.Link, model.Obj, error) { + atomic.AddInt32(&refreshes, 1) + return link, nil, nil + } + + rrr, err := GetRangeReaderFromLink(size, link) + if err != nil { + t.Fatalf("GetRangeReaderFromLink: %v", err) + } + rc, err := rrr.RangeRead(context.Background(), http_range.Range{Start: 0, Length: size}) + if err != nil { + t.Fatalf("RangeRead: %v", err) + } + defer rc.Close() + + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(got, data) { + t.Fatalf("body mismatch: got %q, want %q", got, data) + } + if atomic.LoadInt32(&refreshes) != 0 { + t.Fatalf("refreshes = %d, want 0 (chunked response must not be treated as expired)", refreshes) + } +} + +// TestRangeReaderFromLink_SoftExpiredLink_NoRefresher_ReturnsError verifies +// that when a soft-expired link is encountered without a Refresher set, +// the error surfaces cleanly to the caller instead of returning an empty +// body that the client then misinterprets as a corrupt file. +func TestRangeReaderFromLink_SoftExpiredLink_NoRefresher_ReturnsError(t *testing.T) { + ensureConfForHTTP() + size := int64(100) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + link := &model.Link{URL: server.URL} // no Refresher + + rrr, err := GetRangeReaderFromLink(size, link) + if err != nil { + t.Fatalf("GetRangeReaderFromLink: %v", err) + } + _, err = rrr.RangeRead(context.Background(), http_range.Range{Start: 0, Length: size}) + if err == nil { + t.Fatalf("RangeRead returned nil error; expected an 'expired link' error so callers see a real failure instead of an empty stream") + } + if !IsLinkExpiredError(err) { + t.Fatalf("error %q is not classified as expired by IsLinkExpiredError", err) + } +} + +func sliceForRange(data []byte, httpRange http_range.Range) []byte { + start := int(httpRange.Start) + length := int(httpRange.Length) + if httpRange.Length < 0 || httpRange.Start+httpRange.Length > int64(len(data)) { + length = len(data) - start + } + return data[start : start+length] +} diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go new file mode 100644 index 000000000..54d13c56d --- /dev/null +++ b/pkg/buffer/buffer.go @@ -0,0 +1,41 @@ +package buffer + +import ( + "io" +) + +type byteBlock struct { + buf []byte +} + +func NewByteBlock(buf []byte) Block { + return &byteBlock{buf: buf} +} + +func (b *byteBlock) Size() int64 { + return int64(len(b.buf)) +} + +func (b *byteBlock) ReadAt(p []byte, off int64) (n int, err error) { + if len(b.buf) == 0 || off < 0 || off >= b.Size() { + return 0, io.EOF + } + n = copy(p, b.buf[off:]) + if n < len(p) { + err = io.EOF + } + return +} + +func (b *byteBlock) WriteAt(p []byte, off int64) (n int, err error) { + if len(b.buf) == 0 || off < 0 || off >= b.Size() { + return 0, io.ErrShortWrite + } + n = copy(b.buf[off:], p) + if n < len(p) { + err = io.ErrShortWrite + } + return +} + +var _ Block = (*byteBlock)(nil) diff --git a/pkg/buffer/bytes.go b/pkg/buffer/bytes.go deleted file mode 100644 index 3e6cb5405..000000000 --- a/pkg/buffer/bytes.go +++ /dev/null @@ -1,95 +0,0 @@ -package buffer - -import ( - "errors" - "io" -) - -// 用于存储不复用的[]byte -type Reader struct { - bufs [][]byte - size int64 - offset int64 -} - -func (r *Reader) Size() int64 { - return r.size -} - -func (r *Reader) Append(buf []byte) { - r.size += int64(len(buf)) - r.bufs = append(r.bufs, buf) -} - -func (r *Reader) Read(p []byte) (int, error) { - n, err := r.ReadAt(p, r.offset) - if n > 0 { - r.offset += int64(n) - } - return n, err -} - -func (r *Reader) ReadAt(p []byte, off int64) (int, error) { - if off < 0 || off >= r.size { - return 0, io.EOF - } - - n := 0 - readFrom := false - for _, buf := range r.bufs { - if readFrom { - nn := copy(p[n:], buf) - n += nn - if n == len(p) { - return n, nil - } - } else if newOff := off - int64(len(buf)); newOff >= 0 { - off = newOff - } else { - nn := copy(p, buf[off:]) - if nn == len(p) { - return nn, nil - } - n += nn - readFrom = true - } - } - - return n, io.EOF -} - -func (r *Reader) Seek(offset int64, whence int) (int64, error) { - switch whence { - case io.SeekStart: - case io.SeekCurrent: - offset = r.offset + offset - case io.SeekEnd: - offset = r.size + offset - default: - return 0, errors.New("Seek: invalid whence") - } - - if offset < 0 || offset > r.size { - return 0, errors.New("Seek: invalid offset") - } - - r.offset = offset - return offset, nil -} - -func (r *Reader) Reset() { - clear(r.bufs) - r.bufs = nil - r.size = 0 - r.offset = 0 -} - -func NewReader(buf ...[]byte) *Reader { - b := &Reader{ - bufs: make([][]byte, 0, len(buf)), - } - for _, b1 := range buf { - b.Append(b1) - } - return b -} diff --git a/pkg/buffer/file.go b/pkg/buffer/file.go deleted file mode 100644 index 48edf5a4c..000000000 --- a/pkg/buffer/file.go +++ /dev/null @@ -1,88 +0,0 @@ -package buffer - -import ( - "errors" - "io" - "os" -) - -type PeekFile struct { - peek *Reader - file *os.File - offset int64 - size int64 -} - -func (p *PeekFile) Read(b []byte) (n int, err error) { - n, err = p.ReadAt(b, p.offset) - if n > 0 { - p.offset += int64(n) - } - return n, err -} - -func (p *PeekFile) ReadAt(b []byte, off int64) (n int, err error) { - if off < p.peek.Size() { - n, err = p.peek.ReadAt(b, off) - if err == nil || n == len(b) { - return n, nil - } - // EOF - } - var nn int - nn, err = p.file.ReadAt(b[n:], off+int64(n)-p.peek.Size()) - return n + nn, err -} - -func (p *PeekFile) Seek(offset int64, whence int) (int64, error) { - switch whence { - case io.SeekStart: - case io.SeekCurrent: - if offset == 0 { - return p.offset, nil - } - offset = p.offset + offset - case io.SeekEnd: - offset = p.size + offset - default: - return 0, errors.New("Seek: invalid whence") - } - - if offset < 0 || offset > p.size { - return 0, errors.New("Seek: invalid offset") - } - if offset <= p.peek.Size() { - _, err := p.peek.Seek(offset, io.SeekStart) - if err != nil { - return 0, err - } - _, err = p.file.Seek(0, io.SeekStart) - if err != nil { - return 0, err - } - } else { - _, err := p.peek.Seek(p.peek.Size(), io.SeekStart) - if err != nil { - return 0, err - } - _, err = p.file.Seek(offset-p.peek.Size(), io.SeekStart) - if err != nil { - return 0, err - } - } - - p.offset = offset - return offset, nil -} - -func (p *PeekFile) Size() int64 { - return p.size -} - -func NewPeekFile(peek *Reader, file *os.File) (*PeekFile, error) { - stat, err := file.Stat() - if err == nil { - return &PeekFile{peek: peek, file: file, size: stat.Size() + peek.Size()}, nil - } - return nil, err -} diff --git a/pkg/buffer/pipe.go b/pkg/buffer/pipe.go new file mode 100644 index 000000000..194fd58cf --- /dev/null +++ b/pkg/buffer/pipe.go @@ -0,0 +1,157 @@ +package buffer + +import ( + "context" + "fmt" + "io" + "sync" +) + +type PipeBuffer struct { + limit int //expected size + ctx context.Context + offR int + offW int + rw sync.Mutex + block Block + + readSignal chan struct{} + readPending bool +} + +// NewPipeBuffer is a buffer that can have 1 read & 1 write at the same time. +// when read is faster write, immediately feed data to read after written +func NewPipeBuffer(ctx context.Context, block Block) *PipeBuffer { + br := &PipeBuffer{ + ctx: ctx, + limit: int(block.Size()), + readSignal: make(chan struct{}, 1), + block: block, + } + return br +} + +func (br *PipeBuffer) Read(p []byte) (int, error) { + if err := br.ctx.Err(); err != nil { + return 0, err + } + if len(p) == 0 { + return 0, nil + } + if br.offR >= br.limit { + return 0, io.EOF + } + + for { + br.rw.Lock() + if br.block == nil { + br.rw.Unlock() + return 0, io.ErrClosedPipe + } + + if br.offW == br.offR { + br.readPending = true + br.rw.Unlock() + select { + case <-br.ctx.Done(): + return 0, br.ctx.Err() + case _, ok := <-br.readSignal: + if !ok { + return 0, io.ErrClosedPipe + } + continue + } + } + break + } + + canRead := br.offW - br.offR + if canRead < 0 { + br.rw.Unlock() + return 0, io.ErrUnexpectedEOF + } + + off := br.offR + block := br.block + br.rw.Unlock() + + n, err := block.ReadAt(p[:min(len(p), canRead)], int64(off)) + + br.rw.Lock() + br.offR += n + br.rw.Unlock() + + if n < len(p) && br.offR >= br.limit { + return n, io.EOF + } + return n, err +} + +func (br *PipeBuffer) Write(p []byte) (int, error) { + if err := br.ctx.Err(); err != nil { + return 0, err + } + if len(p) == 0 { + return 0, nil + } + + br.rw.Lock() + if br.block == nil { + br.rw.Unlock() + return 0, io.ErrClosedPipe + } + + canWrite := br.limit - br.offW + if canWrite <= 0 { + br.rw.Unlock() + return 0, io.ErrShortWrite + } + + off := br.offW + block := br.block + br.rw.Unlock() + + n, err := block.WriteAt(p[:min(canWrite, len(p))], int64(off)) + + br.rw.Lock() + br.offW += n + if br.readPending { + br.readPending = false + select { + case br.readSignal <- struct{}{}: + default: + } + } + br.rw.Unlock() + + if n < len(p) && err == nil { + return n, io.ErrShortWrite + } + return n, err +} + +func (br *PipeBuffer) Reset(limit int) error { + br.rw.Lock() + defer br.rw.Unlock() + if br.block == nil { + return io.ErrClosedPipe + } + if int64(limit) > br.block.Size() { + return fmt.Errorf("reset limit %d exceeds max size %d", limit, br.block.Size()) + } + br.limit = limit + br.offR = 0 + br.offW = 0 + return nil +} + +func (br *PipeBuffer) Close() error { + br.rw.Lock() + defer br.rw.Unlock() + if br.block != nil { + br.block = nil + br.readPending = false + close(br.readSignal) + } + return nil +} diff --git a/pkg/buffer/type.go b/pkg/buffer/type.go new file mode 100644 index 000000000..ce0d78b2e --- /dev/null +++ b/pkg/buffer/type.go @@ -0,0 +1,24 @@ +package buffer + +import ( + "io" + + "github.com/OpenListTeam/OpenList/v4/internal/model" +) + +type Block interface { + io.ReaderAt + io.WriterAt + Size() int64 +} + +type WriteAtSeeker = model.FileWriter +type WriteAtSeekerProvider interface{ GetWriteAtSeeker() WriteAtSeeker } + +type ReadAtSeeker = model.File +type ReadAtSeekerProvider interface{ GetReadAtSeeker() ReadAtSeeker } + +type SizedReadAtSeeker interface { + ReadAtSeeker + Size() int64 +} diff --git a/pkg/buffer/utils.go b/pkg/buffer/utils.go new file mode 100644 index 000000000..4783f4b70 --- /dev/null +++ b/pkg/buffer/utils.go @@ -0,0 +1,93 @@ +package buffer + +import ( + "errors" + "io" +) + +func WriteAtSeekerOf(b Block) WriteAtSeeker { + if p, ok := b.(WriteAtSeekerProvider); ok { + return p.GetWriteAtSeeker() + } + return io.NewOffsetWriter(b, 0) +} + +// 将一个Block包装为ReadAtSeeker。 +// 固定大小:当前Block的Size()。 +func ReadAtSeekerOf(b Block) ReadAtSeeker { + if p, ok := b.(ReadAtSeekerProvider); ok { + return p.GetReadAtSeeker() + } + return io.NewSectionReader(b, 0, b.Size()) +} + +type blockAdapter struct { + WriteAtSeeker + SizedReadAtSeeker +} + +func (b *blockAdapter) GetWriteAtSeeker() WriteAtSeeker { + return b.WriteAtSeeker +} + +func (b *blockAdapter) GetReadAtSeeker() ReadAtSeeker { + return b.SizedReadAtSeeker +} +func NewBlockAdapter(w WriteAtSeeker, r SizedReadAtSeeker) Block { + return &blockAdapter{ + WriteAtSeeker: w, + SizedReadAtSeeker: r, + } +} + +var _ Block = (*blockAdapter)(nil) + +// 将一个Block包装为ReadAtSeeker。 +// 动态大小:Size() 是动态跟随底层 Block。 +type DynamicReadAtSeeker struct { + block Block + offset int64 +} + +func (r *DynamicReadAtSeeker) ReadAt(p []byte, off int64) (n int, err error) { + return r.block.ReadAt(p, off) +} + +func (r *DynamicReadAtSeeker) Read(p []byte) (n int, err error) { + n, err = r.block.ReadAt(p, r.offset) + if n > 0 { + r.offset += int64(n) + } + return n, err +} + +func (r *DynamicReadAtSeeker) Size() int64 { + return r.block.Size() +} + +func (r *DynamicReadAtSeeker) Seek(offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + case io.SeekCurrent: + if offset == 0 { + return r.offset, nil + } + offset = r.offset + offset + case io.SeekEnd: + offset = r.block.Size() + offset + default: + return 0, errors.New("Seek: invalid whence") + } + + if offset < 0 || offset > r.block.Size() { + return 0, errors.New("Seek: invalid offset") + } + r.offset = offset + return offset, nil +} + +func NewDynamicReadAtSeeker(block Block) *DynamicReadAtSeeker { + return &DynamicReadAtSeeker{ + block: block, + } +} diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index ce92cd1fc..01cd736d3 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -3,9 +3,7 @@ package pool import "sync" type Pool[T any] struct { - New func() T - MaxCap int - + New func() T cache []T mu sync.Mutex } @@ -24,9 +22,7 @@ func (p *Pool[T]) Get() T { func (p *Pool[T]) Put(item T) { p.mu.Lock() defer p.mu.Unlock() - if p.MaxCap == 0 || len(p.cache) < int(p.MaxCap) { - p.cache = append(p.cache, item) - } + p.cache = append(p.cache, item) } func (p *Pool[T]) Reset() { @@ -35,3 +31,8 @@ func (p *Pool[T]) Reset() { clear(p.cache) p.cache = nil } + +func (p *Pool[T]) Close() error { + p.Reset() + return nil +} diff --git a/pkg/qbittorrent/client.go b/pkg/qbittorrent/client.go index cc8be8707..e4c12db6f 100644 --- a/pkg/qbittorrent/client.go +++ b/pkg/qbittorrent/client.go @@ -99,12 +99,14 @@ func (c *client) login() error { defer resp.Body.Close() // avoid long waiting time if being upgraded to websocket connections (e.g. 101 responses) - // as per API documentation, qBittorrent returns only 200 on successful login - // so we safely treat any non-200 response as a failure - if resp.StatusCode != http.StatusOK { + // as per API documentation, qBittorrent returns only 200 on successful login (qBittorrent < 5.2.0) + // qBittorrent 5.2.0 /api/v2/auth/login returns HTTP 204 on success + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { return errors.New("failed to login into qBittorrent webui with status code: " + resp.Status) } - + if resp.StatusCode == http.StatusNoContent { + return nil + } // check result body := make([]byte, 2) _, err = resp.Body.Read(body) @@ -180,6 +182,10 @@ func (c *client) AddFromLink(link string, savePath string, id string) error { return err } defer resp.Body.Close() + // qBittorrent 5.2.0 returns 204 on success. + if resp.StatusCode != http.StatusNoContent { + return nil + } // check result body := make([]byte, 2) diff --git a/pkg/torrent/bencode.go b/pkg/torrent/bencode.go new file mode 100644 index 000000000..2d4fd782c --- /dev/null +++ b/pkg/torrent/bencode.go @@ -0,0 +1,261 @@ +package torrent + +import ( + "bytes" + "fmt" + "io" + "sort" + "strconv" +) + +// bencode 编码 + +// BencodeEncode 将值编码为 bencode 格式 +func BencodeEncode(v interface{}) ([]byte, error) { + var buf bytes.Buffer + if err := bencodeEncodeValue(&buf, v); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func bencodeEncodeValue(w io.Writer, v interface{}) error { + switch val := v.(type) { + case int: + return bencodeEncodeInt(w, int64(val)) + case int64: + return bencodeEncodeInt(w, val) + case string: + return bencodeEncodeString(w, val) + case []byte: + return bencodeEncodeBytes(w, val) + case []interface{}: + return bencodeEncodeList(w, val) + case map[string]interface{}: + return bencodeEncodeDict(w, val) + case OrderedDict: + return bencodeEncodeOrderedDict(w, val) + default: + return fmt.Errorf("bencode: unsupported type %T", v) + } +} + +func bencodeEncodeInt(w io.Writer, v int64) error { + _, err := fmt.Fprintf(w, "i%de", v) + return err +} + +func bencodeEncodeString(w io.Writer, v string) error { + _, err := fmt.Fprintf(w, "%d:%s", len(v), v) + return err +} + +func bencodeEncodeBytes(w io.Writer, v []byte) error { + _, err := fmt.Fprintf(w, "%d:", len(v)) + if err != nil { + return err + } + _, err = w.Write(v) + return err +} + +func bencodeEncodeList(w io.Writer, v []interface{}) error { + if _, err := w.Write([]byte("l")); err != nil { + return err + } + for _, item := range v { + if err := bencodeEncodeValue(w, item); err != nil { + return err + } + } + _, err := w.Write([]byte("e")) + return err +} + +func bencodeEncodeDict(w io.Writer, v map[string]interface{}) error { + // bencode 字典要求 key 按字典序排列 + keys := make([]string, 0, len(v)) + for k := range v { + keys = append(keys, k) + } + sort.Strings(keys) + + if _, err := w.Write([]byte("d")); err != nil { + return err + } + for _, k := range keys { + if err := bencodeEncodeString(w, k); err != nil { + return err + } + if err := bencodeEncodeValue(w, v[k]); err != nil { + return err + } + } + _, err := w.Write([]byte("e")) + return err +} + +// OrderedDict 有序字典,保持插入顺序 +type OrderedDict struct { + Keys []string + Values map[string]interface{} +} + +func NewOrderedDict() OrderedDict { + return OrderedDict{ + Keys: make([]string, 0), + Values: make(map[string]interface{}), + } +} + +func (d *OrderedDict) Set(key string, value interface{}) { + if _, exists := d.Values[key]; !exists { + d.Keys = append(d.Keys, key) + } + d.Values[key] = value +} + +func (d *OrderedDict) Get(key string) (interface{}, bool) { + v, ok := d.Values[key] + return v, ok +} + +func bencodeEncodeOrderedDict(w io.Writer, d OrderedDict) error { + // 按字典序排列 key(bencode 规范要求) + keys := make([]string, len(d.Keys)) + copy(keys, d.Keys) + sort.Strings(keys) + + if _, err := w.Write([]byte("d")); err != nil { + return err + } + for _, k := range keys { + if err := bencodeEncodeString(w, k); err != nil { + return err + } + if err := bencodeEncodeValue(w, d.Values[k]); err != nil { + return err + } + } + _, err := w.Write([]byte("e")) + return err +} + +// bencode 解码 + +// BencodeDecode 从字节数组解码 bencode 数据 +func BencodeDecode(data []byte) (interface{}, error) { + reader := bytes.NewReader(data) + val, err := bencodeDecodeValue(reader) + if err != nil { + return nil, err + } + return val, nil +} + +func bencodeDecodeValue(r *bytes.Reader) (interface{}, error) { + b, err := r.ReadByte() + if err != nil { + return nil, err + } + + switch { + case b == 'i': + return bencodeDecodeInt(r) + case b == 'l': + return bencodeDecodeList(r) + case b == 'd': + return bencodeDecodeDict(r) + case b >= '0' && b <= '9': + r.UnreadByte() + return bencodeDecodeString(r) + default: + return nil, fmt.Errorf("bencode: unexpected byte '%c' at position %d", b, int64(r.Len())) + } +} + +func bencodeDecodeInt(r *bytes.Reader) (int64, error) { + var buf bytes.Buffer + for { + b, err := r.ReadByte() + if err != nil { + return 0, err + } + if b == 'e' { + break + } + buf.WriteByte(b) + } + return strconv.ParseInt(buf.String(), 10, 64) +} + +func bencodeDecodeString(r *bytes.Reader) ([]byte, error) { + // 读取长度 + var lenBuf bytes.Buffer + for { + b, err := r.ReadByte() + if err != nil { + return nil, err + } + if b == ':' { + break + } + lenBuf.WriteByte(b) + } + length, err := strconv.ParseInt(lenBuf.String(), 10, 64) + if err != nil { + return nil, fmt.Errorf("bencode: invalid string length: %v", err) + } + if length < 0 || length > 100*1024*1024 { + return nil, fmt.Errorf("bencode: string length out of bounds: %d", length) + } + // Safe to convert to int: bounds check above ensures length <= 100MB which fits in int32 + data := make([]byte, int(length)) + _, err = io.ReadFull(r, data) + if err != nil { + return nil, err + } + return data, nil +} + +func bencodeDecodeList(r *bytes.Reader) ([]interface{}, error) { + var list []interface{} + for { + b, err := r.ReadByte() + if err != nil { + return nil, err + } + if b == 'e' { + return list, nil + } + r.UnreadByte() + val, err := bencodeDecodeValue(r) + if err != nil { + return nil, err + } + list = append(list, val) + } +} + +func bencodeDecodeDict(r *bytes.Reader) (map[string]interface{}, error) { + dict := make(map[string]interface{}) + for { + b, err := r.ReadByte() + if err != nil { + return nil, err + } + if b == 'e' { + return dict, nil + } + r.UnreadByte() + keyBytes, err := bencodeDecodeString(r) + if err != nil { + return nil, err + } + val, err := bencodeDecodeValue(r) + if err != nil { + return nil, err + } + dict[string(keyBytes)] = val + } +} diff --git a/pkg/torrent/generate.go b/pkg/torrent/generate.go new file mode 100644 index 000000000..566cad86f --- /dev/null +++ b/pkg/torrent/generate.go @@ -0,0 +1,123 @@ +package torrent + +import ( + "io" + "os" + "strings" +) + +// GenerateFromFile 从文件路径生成通用的 torrent 文件(不含 CAS 扩展) +// 这是一个通用函数,适用于所有驱动 +func GenerateFromFile(filePath string) ([]byte, error) { + f, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer f.Close() + + info, err := f.Stat() + if err != nil { + return nil, err + } + + return GenerateFromReader(f, info.Name(), info.Size(), DefaultPieceSize) +} + +// GenerateFromReader 从 io.Reader 生成通用的 torrent 文件(不含 CAS 扩展) +// 返回 torrent 字节数据 +func GenerateFromReader(reader io.Reader, fileName string, fileSize int64, pieceSize int64) ([]byte, error) { + if pieceSize <= 0 { + pieceSize = DefaultPieceSize + } + + hw := NewHashWriter(pieceSize, pieceSize) + + buf := make([]byte, 32*1024) + for { + n, err := reader.Read(buf) + if n > 0 { + hw.Write(buf[:n]) + } + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + } + hw.Finish() + + fileMD5 := hw.GetFileMD5() + pieceHashes := hw.GetPieceHashes() + + t := NewTorrent(fileName, fileSize, fileMD5) + t.Info.PieceLength = pieceSize + t.SetPieces(pieceHashes) + + return t.Encode() +} + +// GenerateFromReaderWithCAS 从 io.Reader 生成包含 CAS 扩展的 torrent 文件 +// 适用于天翼云等支持秒传的网盘 +func GenerateFromReaderWithCAS(reader io.Reader, fileName string, fileSize int64, pieceSize int64) ([]byte, error) { + if pieceSize <= 0 { + pieceSize = DefaultPieceSize + } + + hw := NewHashWriter(pieceSize, pieceSize) + + buf := make([]byte, 32*1024) + for { + n, err := reader.Read(buf) + if n > 0 { + hw.Write(buf[:n]) + } + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + } + hw.Finish() + + fileMD5 := hw.GetFileMD5() + sliceMD5s := hw.GetSliceMD5s() + pieceHashes := hw.GetPieceHashes() + + // 计算 sliceMD5 + sliceMD5 := fileMD5 + if len(sliceMD5s) > 1 { + joined := strings.Join(sliceMD5s, "\n") + sliceMD5 = strings.ToUpper(GetMD5Str(joined)) + } + + t := NewTorrent(fileName, fileSize, fileMD5) + t.Info.PieceLength = pieceSize + t.SetPieces(pieceHashes) + t.SetCASInfo(&CASInfo{ + FileMD5: fileMD5, + SliceMD5: sliceMD5, + SliceMD5s: sliceMD5s, + SliceSize: pieceSize, + Cloud: "189", + }) + + return t.Encode() +} + +// GenerateFromFileWithCAS 从文件路径生成包含 CAS 扩展的 torrent 文件 +func GenerateFromFileWithCAS(filePath string) ([]byte, error) { + f, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer f.Close() + + info, err := f.Stat() + if err != nil { + return nil, err + } + + return GenerateFromReaderWithCAS(f, info.Name(), info.Size(), DefaultPieceSize) +} diff --git a/pkg/torrent/hash_writer.go b/pkg/torrent/hash_writer.go new file mode 100644 index 000000000..a62f42c64 --- /dev/null +++ b/pkg/torrent/hash_writer.go @@ -0,0 +1,229 @@ +package torrent + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "fmt" + "hash" + "io" + "strings" +) + +// HashWriter 同时计算文件的 MD5、分片 MD5 和 SHA-1 piece hash +// 用于在上传过程中一次性计算所有需要的哈希值 +type HashWriter struct { + // 整文件 MD5 + fileMD5 hash.Hash + // 当前分片 MD5 + sliceMD5 hash.Hash + // 当前 piece 的 SHA-1 + pieceSHA1 hash.Hash + + // 分片大小(默认 10MB) + sliceSize int64 + // piece 大小(与 sliceSize 相同,保持对齐) + pieceSize int64 + + // 当前分片已写入字节数 + sliceWritten int64 + // 当前 piece 已写入字节数 + pieceWritten int64 + // 总写入字节数 + totalWritten int64 + + // 每个分片的 MD5(大写十六进制) + sliceMD5Hexs []string + // 所有 piece 的 SHA-1 哈希拼接 + pieceHashes []byte +} + +// NewHashWriter 创建一个新的 HashWriter +// sliceSize: CAS 分片大小(通常 10MB) +// pieceSize: BT piece 大小(设为与 sliceSize 相同以保持对齐) +func NewHashWriter(sliceSize, pieceSize int64) *HashWriter { + if sliceSize <= 0 { + sliceSize = DefaultPieceSize + } + if pieceSize <= 0 { + pieceSize = DefaultPieceSize + } + return &HashWriter{ + fileMD5: md5.New(), + sliceMD5: md5.New(), + pieceSHA1: sha1.New(), + sliceSize: sliceSize, + pieceSize: pieceSize, + } +} + +// NewDefaultHashWriter 创建默认的 HashWriter(10MB 分片) +func NewDefaultHashWriter() *HashWriter { + return NewHashWriter(DefaultPieceSize, DefaultPieceSize) +} + +// Write 实现 io.Writer 接口 +func (hw *HashWriter) Write(p []byte) (n int, err error) { + total := len(p) + offset := 0 + + for offset < total { + // 计算当前可以写入的字节数(取分片和 piece 剩余空间的最小值) + sliceRemain := hw.sliceSize - hw.sliceWritten + pieceRemain := hw.pieceSize - hw.pieceWritten + canWrite := min64(sliceRemain, pieceRemain) + canWrite = min64(canWrite, int64(total-offset)) + + chunk := p[offset : offset+int(canWrite)] + + // 写入整文件 MD5 + hw.fileMD5.Write(chunk) + // 写入当前分片 MD5 + hw.sliceMD5.Write(chunk) + // 写入当前 piece SHA-1 + hw.pieceSHA1.Write(chunk) + + hw.sliceWritten += canWrite + hw.pieceWritten += canWrite + hw.totalWritten += canWrite + offset += int(canWrite) + + // 检查分片是否完成 + if hw.sliceWritten >= hw.sliceSize { + hw.finishSlice() + } + + // 检查 piece 是否完成 + if hw.pieceWritten >= hw.pieceSize { + hw.finishPiece() + } + } + + return total, nil +} + +// finishSlice 完成当前分片的 MD5 计算 +func (hw *HashWriter) finishSlice() { + md5Hex := strings.ToUpper(hex.EncodeToString(hw.sliceMD5.Sum(nil))) + hw.sliceMD5Hexs = append(hw.sliceMD5Hexs, md5Hex) + hw.sliceMD5.Reset() + hw.sliceWritten = 0 +} + +// finishPiece 完成当前 piece 的 SHA-1 计算 +func (hw *HashWriter) finishPiece() { + hw.pieceHashes = append(hw.pieceHashes, hw.pieceSHA1.Sum(nil)...) + hw.pieceSHA1.Reset() + hw.pieceWritten = 0 +} + +// Finish 完成所有哈希计算(处理最后不完整的分片/piece) +func (hw *HashWriter) Finish() { + // 处理最后一个不完整的分片 + if hw.sliceWritten > 0 { + hw.finishSlice() + } + // 处理最后一个不完整的 piece + if hw.pieceWritten > 0 { + hw.finishPiece() + } +} + +// GetFileMD5 获取整文件 MD5(大写十六进制) +func (hw *HashWriter) GetFileMD5() string { + return strings.ToUpper(hex.EncodeToString(hw.fileMD5.Sum(nil))) +} + +// GetSliceMD5s 获取所有分片的 MD5 列表 +func (hw *HashWriter) GetSliceMD5s() []string { + return hw.sliceMD5Hexs +} + +// GetSliceMD5 获取最终的 sliceMD5(用于秒传) +func (hw *HashWriter) GetSliceMD5(fileMD5 string) string { + if len(hw.sliceMD5Hexs) <= 1 { + return fileMD5 + } + joined := strings.Join(hw.sliceMD5Hexs, "\n") + return strings.ToUpper(GetMD5Str(joined)) +} + +// GetPieceHashes 获取所有 piece 的 SHA-1 哈希拼接 +func (hw *HashWriter) GetPieceHashes() []byte { + return hw.pieceHashes +} + +// GetTotalWritten 获取总写入字节数 +func (hw *HashWriter) GetTotalWritten() int64 { + return hw.totalWritten +} + +// BuildTorrent 根据计算结果构建 Torrent 结构 +func (hw *HashWriter) BuildTorrent(fileName string, fileSize int64) *Torrent { + fileMD5 := hw.GetFileMD5() + sliceMD5 := hw.GetSliceMD5(fileMD5) + + t := NewTorrent(fileName, fileSize, fileMD5) + t.SetPieces(hw.GetPieceHashes()) + t.SetCASInfo(&CASInfo{ + FileMD5: fileMD5, + SliceMD5: sliceMD5, + SliceMD5s: hw.GetSliceMD5s(), + SliceSize: hw.sliceSize, + Cloud: "189", + }) + + return t +} + +// BuildTorrentBytes 构建并编码 torrent 文件 +func (hw *HashWriter) BuildTorrentBytes(fileName string, fileSize int64) ([]byte, error) { + t := hw.BuildTorrent(fileName, fileSize) + return t.Encode() +} + +// CopyAndHash 从 reader 读取数据,同时写入 writer 和 HashWriter +func CopyAndHash(dst io.Writer, src io.Reader, hw *HashWriter) (int64, error) { + buf := make([]byte, 32*1024) // 32KB buffer + var written int64 + for { + nr, er := src.Read(buf) + if nr > 0 { + // 写入 HashWriter + hw.Write(buf[:nr]) + // 写入目标 + if dst != nil { + nw, ew := dst.Write(buf[:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = fmt.Errorf("invalid write result") + } + } + written += int64(nw) + if ew != nil { + return written, ew + } + if nr != nw { + return written, io.ErrShortWrite + } + } else { + written += int64(nr) + } + } + if er != nil { + if er == io.EOF { + break + } + return written, er + } + } + return written, nil +} + +func min64(a, b int64) int64 { + if a < b { + return a + } + return b +} diff --git a/pkg/torrent/torrent.go b/pkg/torrent/torrent.go new file mode 100644 index 000000000..8744e6362 --- /dev/null +++ b/pkg/torrent/torrent.go @@ -0,0 +1,439 @@ +package torrent + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "fmt" + "strings" + "time" +) + +const ( + // DefaultPieceSize 默认分片大小 10MB,与天翼云 CAS 分片大小一致 + DefaultPieceSize int64 = 10 * 1024 * 1024 + + // CASExtensionKey torrent 根字典中的 CAS 扩展 key + CASExtensionKey = "x-cas" + + // CASSliceSizeKey CAS 分片大小 key + CASSliceSizeKey = "slice_size" + // CASSliceMD5sKey 每片 MD5 列表 key + CASSliceMD5sKey = "slice_md5s" + // CASSliceMD5Key 最终 sliceMd5 key + CASSliceMD5Key = "slice_md5" + // CASFileMD5Key 整文件 MD5 key + CASFileMD5Key = "file_md5" + // CASCloudKey 云盘类型 key + CASCloudKey = "cloud" +) + +// CASInfo 天翼云 CAS 秒传所需信息 +type CASInfo struct { + // FileMD5 整文件 MD5(大写十六进制) + FileMD5 string + // SliceMD5 分片 MD5 的摘要(大写十六进制) + SliceMD5 string + // SliceMD5s 每个 10MB 分片的 MD5(大写十六进制) + SliceMD5s []string + // SliceSize 分片大小(字节) + SliceSize int64 + // Cloud 云盘类型标识 + Cloud string +} + +// TorrentFile 表示 torrent 中的单个文件 +type TorrentFile struct { + // Length 文件大小(字节) + Length int64 + // Path 文件路径(多文件模式下的相对路径各段) + Path []string + // MD5Sum 文件的 MD5(可选,BT 标准字段) + MD5Sum string +} + +// TorrentInfo torrent 的 info 字典 +type TorrentInfo struct { + // PieceLength 分片大小 + PieceLength int64 + // Pieces 所有分片的 SHA-1 哈希拼接(每 20 字节一个) + Pieces []byte + // Name 种子名称(单文件模式为文件名,多文件模式为目录名) + Name string + // Length 单文件模式下的文件大小 + Length int64 + // Files 多文件模式下的文件列表 + Files []TorrentFile + // MD5Sum 单文件模式下的文件 MD5(可选) + MD5Sum string +} + +// Torrent 完整的 torrent 文件结构 +type Torrent struct { + // Info info 字典 + Info TorrentInfo + // InfoHash info 字典的 SHA-1 哈希(20 字节) + InfoHash []byte + // Announce tracker URL + Announce string + // AnnounceList tracker 列表 + AnnounceList [][]string + // CreationDate 创建时间 + CreationDate int64 + // Comment 注释 + Comment string + // CreatedBy 创建者 + CreatedBy string + // CAS 天翼云 CAS 扩展信息(存储在 info 字典外部,不影响 info_hash) + CAS *CASInfo +} + +// NewTorrent 创建一个新的 torrent 结构 +func NewTorrent(name string, fileSize int64, fileMD5 string) *Torrent { + return &Torrent{ + Info: TorrentInfo{ + PieceLength: DefaultPieceSize, + Name: name, + Length: fileSize, + MD5Sum: fileMD5, + }, + CreationDate: time.Now().Unix(), + CreatedBy: "OpenList", + Comment: "Generated by OpenList with CAS extension", + } +} + +// SetPieces 设置 SHA-1 分片哈希 +func (t *Torrent) SetPieces(pieces []byte) { + t.Info.Pieces = pieces +} + +// SetCASInfo 设置 CAS 扩展信息 +func (t *Torrent) SetCASInfo(cas *CASInfo) { + t.CAS = cas +} + +// Encode 将 torrent 编码为 bencode 格式的字节数组 +func (t *Torrent) Encode() ([]byte, error) { + // 构建 info 字典 + infoDict := make(map[string]interface{}) + infoDict["piece length"] = int64(t.Info.PieceLength) + infoDict["pieces"] = t.Info.Pieces + infoDict["name"] = t.Info.Name + + if len(t.Info.Files) > 0 { + // 多文件模式 + files := make([]interface{}, 0, len(t.Info.Files)) + for _, f := range t.Info.Files { + fileDict := make(map[string]interface{}) + fileDict["length"] = int64(f.Length) + path := make([]interface{}, 0, len(f.Path)) + for _, p := range f.Path { + path = append(path, p) + } + fileDict["path"] = path + if f.MD5Sum != "" { + fileDict["md5sum"] = f.MD5Sum + } + files = append(files, fileDict) + } + infoDict["files"] = files + } else { + // 单文件模式 + infoDict["length"] = int64(t.Info.Length) + if t.Info.MD5Sum != "" { + infoDict["md5sum"] = t.Info.MD5Sum + } + } + + // 编码 info 字典并计算 info_hash + infoBytes, err := BencodeEncode(infoDict) + if err != nil { + return nil, fmt.Errorf("encode info dict: %w", err) + } + infoHashRaw := sha1.Sum(infoBytes) + t.InfoHash = infoHashRaw[:] + + // 构建根字典 + rootDict := make(map[string]interface{}) + if t.Announce != "" { + rootDict["announce"] = t.Announce + } + if len(t.AnnounceList) > 0 { + announceList := make([]interface{}, 0, len(t.AnnounceList)) + for _, tier := range t.AnnounceList { + tierList := make([]interface{}, 0, len(tier)) + for _, url := range tier { + tierList = append(tierList, url) + } + announceList = append(announceList, tierList) + } + rootDict["announce-list"] = announceList + } + if t.Comment != "" { + rootDict["comment"] = t.Comment + } + if t.CreatedBy != "" { + rootDict["created by"] = t.CreatedBy + } + if t.CreationDate > 0 { + rootDict["creation date"] = t.CreationDate + } + + // info 字典使用原始编码的字节(保证 info_hash 一致) + rootDict["info"] = infoDict + + // CAS 扩展信息(放在 info 外部,不影响 info_hash) + if t.CAS != nil { + casDict := make(map[string]interface{}) + casDict[CASCloudKey] = t.CAS.Cloud + casDict[CASFileMD5Key] = t.CAS.FileMD5 + casDict[CASSliceMD5Key] = t.CAS.SliceMD5 + casDict[CASSliceSizeKey] = t.CAS.SliceSize + + if len(t.CAS.SliceMD5s) > 0 { + md5List := make([]interface{}, 0, len(t.CAS.SliceMD5s)) + for _, md5 := range t.CAS.SliceMD5s { + md5List = append(md5List, md5) + } + casDict[CASSliceMD5sKey] = md5List + } + rootDict[CASExtensionKey] = casDict + } + + return BencodeEncode(rootDict) +} + +// Decode 从 bencode 字节数组解析 torrent +func Decode(data []byte) (*Torrent, error) { + val, err := BencodeDecode(data) + if err != nil { + return nil, fmt.Errorf("bencode decode: %w", err) + } + + rootDict, ok := val.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("torrent: root is not a dict") + } + + t := &Torrent{} + + // 解析 announce + if v, ok := rootDict["announce"]; ok { + if b, ok := v.([]byte); ok { + t.Announce = string(b) + } + } + + // 解析 announce-list + if v, ok := rootDict["announce-list"]; ok { + if list, ok := v.([]interface{}); ok { + for _, tier := range list { + if tierList, ok := tier.([]interface{}); ok { + var urls []string + for _, u := range tierList { + if b, ok := u.([]byte); ok { + urls = append(urls, string(b)) + } + } + if len(urls) > 0 { + t.AnnounceList = append(t.AnnounceList, urls) + } + } + } + } + } + + // 解析 comment + if v, ok := rootDict["comment"]; ok { + if b, ok := v.([]byte); ok { + t.Comment = string(b) + } + } + + // 解析 created by + if v, ok := rootDict["created by"]; ok { + if b, ok := v.([]byte); ok { + t.CreatedBy = string(b) + } + } + + // 解析 creation date + if v, ok := rootDict["creation date"]; ok { + if n, ok := v.(int64); ok { + t.CreationDate = n + } + } + + // 解析 info 字典 + if infoVal, ok := rootDict["info"]; ok { + infoDict, ok := infoVal.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("torrent: info is not a dict") + } + + // 计算 info_hash + infoBytes, err := BencodeEncode(infoDict) + if err != nil { + return nil, fmt.Errorf("encode info for hash: %w", err) + } + infoHashRaw := sha1.Sum(infoBytes) + t.InfoHash = infoHashRaw[:] + + // 解析 info 字段 + if v, ok := infoDict["piece length"]; ok { + if n, ok := v.(int64); ok { + t.Info.PieceLength = n + } + } + if v, ok := infoDict["pieces"]; ok { + if b, ok := v.([]byte); ok { + t.Info.Pieces = b + } + } + if v, ok := infoDict["name"]; ok { + if b, ok := v.([]byte); ok { + t.Info.Name = string(b) + } + } + if v, ok := infoDict["length"]; ok { + if n, ok := v.(int64); ok { + t.Info.Length = n + } + } + if v, ok := infoDict["md5sum"]; ok { + if b, ok := v.([]byte); ok { + t.Info.MD5Sum = string(b) + } + } + + // 解析多文件模式 + if v, ok := infoDict["files"]; ok { + if files, ok := v.([]interface{}); ok { + for _, f := range files { + if fileDict, ok := f.(map[string]interface{}); ok { + tf := TorrentFile{} + if l, ok := fileDict["length"]; ok { + if n, ok := l.(int64); ok { + tf.Length = n + } + } + if p, ok := fileDict["path"]; ok { + if pathList, ok := p.([]interface{}); ok { + for _, pp := range pathList { + if b, ok := pp.([]byte); ok { + tf.Path = append(tf.Path, string(b)) + } + } + } + } + if m, ok := fileDict["md5sum"]; ok { + if b, ok := m.([]byte); ok { + tf.MD5Sum = string(b) + } + } + t.Info.Files = append(t.Info.Files, tf) + } + } + } + } + } + + // 解析 CAS 扩展 + if casVal, ok := rootDict[CASExtensionKey]; ok { + if casDict, ok := casVal.(map[string]interface{}); ok { + cas := &CASInfo{} + if v, ok := casDict[CASCloudKey]; ok { + if b, ok := v.([]byte); ok { + cas.Cloud = string(b) + } + } + if v, ok := casDict[CASFileMD5Key]; ok { + if b, ok := v.([]byte); ok { + cas.FileMD5 = string(b) + } + } + if v, ok := casDict[CASSliceMD5Key]; ok { + if b, ok := v.([]byte); ok { + cas.SliceMD5 = string(b) + } + } + if v, ok := casDict[CASSliceSizeKey]; ok { + if n, ok := v.(int64); ok { + cas.SliceSize = n + } + } + if v, ok := casDict[CASSliceMD5sKey]; ok { + if list, ok := v.([]interface{}); ok { + for _, item := range list { + if b, ok := item.([]byte); ok { + cas.SliceMD5s = append(cas.SliceMD5s, string(b)) + } + } + } + } + t.CAS = cas + } + } + + return t, nil +} + +// GetInfoHashHex 获取 info_hash 的十六进制字符串 +func (t *Torrent) GetInfoHashHex() string { + return hex.EncodeToString(t.InfoHash) +} + +// GetPieceHashes 获取所有分片的 SHA-1 哈希(每个 20 字节) +func (t *Torrent) GetPieceHashes() [][]byte { + if len(t.Info.Pieces) == 0 { + return nil + } + count := len(t.Info.Pieces) / 20 + hashes := make([][]byte, count) + for i := 0; i < count; i++ { + hashes[i] = t.Info.Pieces[i*20 : (i+1)*20] + } + return hashes +} + +// GetTotalSize 获取 torrent 中所有文件的总大小 +func (t *Torrent) GetTotalSize() int64 { + if len(t.Info.Files) > 0 { + var total int64 + for _, f := range t.Info.Files { + total += f.Length + } + return total + } + return t.Info.Length +} + +// HasCASInfo 检查 torrent 是否包含 CAS 扩展信息 +func (t *Torrent) HasCASInfo() bool { + return t.CAS != nil && t.CAS.FileMD5 != "" && t.CAS.SliceMD5 != "" +} + +// BuildCASInfoFromMD5s 从分片 MD5 列表构建 CAS 信息 +func BuildCASInfoFromMD5s(fileMD5 string, sliceMD5s []string, sliceSize int64) *CASInfo { + sliceMD5 := fileMD5 + if len(sliceMD5s) > 1 { + // 所有分片 MD5 用 \n 拼接后再取 MD5 + joined := strings.Join(sliceMD5s, "\n") + sliceMD5 = strings.ToUpper(GetMD5Str(joined)) + } + return &CASInfo{ + FileMD5: fileMD5, + SliceMD5: sliceMD5, + SliceMD5s: sliceMD5s, + SliceSize: sliceSize, + Cloud: "189", + } +} + +// GetMD5Str 计算字符串的 MD5(大写十六进制) +func GetMD5Str(data string) string { + h := md5.New() + h.Write([]byte(data)) + return strings.ToUpper(hex.EncodeToString(h.Sum(nil))) +} diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go index 596e61e54..c4b4e735f 100644 --- a/pkg/utils/hash.go +++ b/pkg/utils/hash.go @@ -90,6 +90,12 @@ var ( // SHA256 indicates SHA-256 support SHA256 = RegisterHash("sha256", "SHA-256", 64, sha256.New) + + // SHA1_128K is SHA1 of first 128KB, used by 115 driver for rapid upload + SHA1_128K = RegisterHash("sha1_128k", "SHA1-128K", 40, sha1.New) + + // PRE_HASH is SHA1 of first 1024 bytes, used by Aliyundrive for rapid upload + PRE_HASH = RegisterHash("pre_hash", "PRE-HASH", 40, sha1.New) ) // HashData get hash of one hashType diff --git a/public/dist/README.md b/public/dist/README.md deleted file mode 100644 index d8709fb57..000000000 --- a/public/dist/README.md +++ /dev/null @@ -1 +0,0 @@ -## Put dist of frontend here. \ No newline at end of file diff --git a/server/common/common.go b/server/common/common.go index d78051268..a1a9a6319 100644 --- a/server/common/common.go +++ b/server/common/common.go @@ -128,13 +128,24 @@ func Pluralize(count int, singular, plural string) string { return plural } -func GinWithValue(c *gin.Context, keyAndValue ...any) { +type requestContext struct { + context.Context +} + +// GinAppendValues 向当前请求上下文追加键值,提供类似 gin.Context Set/Get 的可变语义。 +// 同一请求内,已持有的上下文引用会同步看到后续更新。 +func GinAppendValues(c *gin.Context, keyAndValue ...any) { + ctx := c.Request.Context() + if r, ok := ctx.(*requestContext); ok { + r.Context = ContentWithValues(r.Context, keyAndValue...) + return + } c.Request = c.Request.WithContext( - ContentWithValue(c.Request.Context(), keyAndValue...), + &requestContext{ContentWithValues(ctx, keyAndValue...)}, ) } -func ContentWithValue(ctx context.Context, keyAndValue ...any) context.Context { +func ContentWithValues(ctx context.Context, keyAndValue ...any) context.Context { if len(keyAndValue) < 1 || len(keyAndValue)%2 != 0 { panic("keyAndValue must be an even number of arguments (key, value, ...)") } diff --git a/server/handles/archive.go b/server/handles/archive.go index d46f83c86..364e93edc 100644 --- a/server/handles/archive.go +++ b/server/handles/archive.go @@ -105,7 +105,7 @@ func FsArchiveMeta(c *gin.Context, req *ArchiveMetaReq, user *model.User) { common.ErrorResp(c, err, 500, true) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return @@ -188,7 +188,7 @@ func FsArchiveList(c *gin.Context, req *ArchiveListReq, user *model.User) { common.ErrorResp(c, err, 500, true) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return diff --git a/server/handles/fsbatch.go b/server/handles/fsbatch.go index 28588d668..e672165cd 100644 --- a/server/handles/fsbatch.go +++ b/server/handles/fsbatch.go @@ -49,7 +49,7 @@ func FsRecursiveMove(c *gin.Context) { common.ErrorResp(c, errs.PermissionDenied, 403) return } - common.GinWithValue(c, conf.MetaKey, srcMeta) + common.GinAppendValues(c, conf.MetaKey, srcMeta) dstDir, err := user.JoinPath(req.DstDir) if err != nil { @@ -183,7 +183,7 @@ func FsBatchRename(c *gin.Context) { common.ErrorResp(c, errs.PermissionDenied, 403) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) for _, renameObject := range req.RenameObjects { if renameObject.SrcName == "" || renameObject.NewName == "" { continue @@ -236,7 +236,7 @@ func FsRegexRename(c *gin.Context) { common.ErrorResp(c, errs.PermissionDenied, 403) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) srcRegexp, err := regexp.Compile(req.SrcNameRegex) if err != nil { diff --git a/server/handles/fsmanage.go b/server/handles/fsmanage.go index 94e7c3fec..d97d36d24 100644 --- a/server/handles/fsmanage.go +++ b/server/handles/fsmanage.go @@ -119,6 +119,7 @@ func FsMove(c *gin.Context) { req.Names[i] = "" continue } + req.Names[i] = srcPath if !req.Overwrite { base := stdpath.Base(srcPath) if base == "." || base == "/" { @@ -129,12 +130,11 @@ func FsMove(c *gin.Context) { if !req.SkipExisting { common.ErrorStrResp(c, fmt.Sprintf("file [%s] exists", name), 403) return - } else { - continue } + req.Names[i] = "" + continue } } - req.Names[i] = srcPath } // Create all tasks immediately without any synchronous validation @@ -222,6 +222,7 @@ func FsCopy(c *gin.Context) { req.Names[i] = "" continue } + req.Names[i] = srcPath if !req.Overwrite { base := stdpath.Base(srcPath) if base == "." || base == "/" { @@ -233,11 +234,11 @@ func FsCopy(c *gin.Context) { common.ErrorStrResp(c, fmt.Sprintf("file [%s] exists", name), 403) return } else if !req.Merge || !res.IsDir() { + req.Names[i] = "" continue } } } - req.Names[i] = srcPath } // Create all tasks immediately without any synchronous validation @@ -425,7 +426,7 @@ func FsRemoveEmptyDirectory(c *gin.Context) { common.ErrorResp(c, errs.PermissionDenied, 403) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) rootFiles, err := fs.List(c.Request.Context(), srcDir, &fs.ListArgs{}) if err != nil { diff --git a/server/handles/fsmanage_test.go b/server/handles/fsmanage_test.go new file mode 100644 index 000000000..e2de6508c --- /dev/null +++ b/server/handles/fsmanage_test.go @@ -0,0 +1,100 @@ +package handles + +import ( + "testing" +) + +// TestClearSkippedNames verifies the contract from upstream fix #2520: +// When SkipExisting is true and a file already exists at the destination, +// that name must be cleared (set to "") so the task loop skips it. +// Valid (non-skipped) names must be set to full srcPath. +// +// This is a contract test for the name-filtering loop in FsMove/FsCopy. +// It tests the logic in isolation without needing gin/fs dependencies. +func TestClearSkippedNames(t *testing.T) { + // Simulate the name-filtering loop logic from FsMove/FsCopy. + // existsAtDst simulates whether a file already exists at the destination. + filterNames := func(srcDir string, names []string, overwrite, skipExisting bool, existsAtDst func(name string) bool) []string { + result := make([]string, len(names)) + copy(result, names) + for i, name := range result { + srcPath := srcDir + "/" + name + // First: set to srcPath (the fix moves this before skip check) + result[i] = srcPath + if !overwrite { + if existsAtDst(name) { + if !skipExisting { + // Would return error in real code + result[i] = "ERROR" + return result + } + // Skip: must clear to "" + result[i] = "" + continue + } + } + } + return result + } + + tests := []struct { + name string + srcDir string + names []string + overwrite bool + skipExisting bool + existsAtDst func(string) bool + want []string + }{ + { + name: "skip existing files clears their names", + srcDir: "/src", + names: []string{"a.mkv", "b.mkv", "c.mkv"}, + overwrite: false, + skipExisting: true, + existsAtDst: func(n string) bool { return n == "b.mkv" }, + want: []string{"/src/a.mkv", "", "/src/c.mkv"}, + }, + { + name: "no skipping when overwrite is true", + srcDir: "/src", + names: []string{"a.mkv", "b.mkv"}, + overwrite: true, + skipExisting: false, + existsAtDst: func(string) bool { return true }, + want: []string{"/src/a.mkv", "/src/b.mkv"}, + }, + { + name: "all files skipped results in all empty", + srcDir: "/src", + names: []string{"x.mp4", "y.mp4"}, + overwrite: false, + skipExisting: true, + existsAtDst: func(string) bool { return true }, + want: []string{"", ""}, + }, + { + name: "no existing files keeps all paths", + srcDir: "/data", + names: []string{"movie.mkv"}, + overwrite: false, + skipExisting: true, + existsAtDst: func(string) bool { return false }, + want: []string{"/data/movie.mkv"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := filterNames(tt.srcDir, tt.names, tt.overwrite, tt.skipExisting, tt.existsAtDst) + if len(got) != len(tt.want) { + t.Fatalf("length mismatch: got %d, want %d", len(got), len(tt.want)) + } + for i := range tt.want { + if got[i] != tt.want[i] { + t.Errorf("names[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +} diff --git a/server/handles/fsread.go b/server/handles/fsread.go index a90fc1082..8fd731af7 100644 --- a/server/handles/fsread.go +++ b/server/handles/fsread.go @@ -88,7 +88,7 @@ func FsList(c *gin.Context, req *ListReq, user *model.User) { common.ErrorResp(c, err, 500, true) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return @@ -152,7 +152,7 @@ func FsDirs(c *gin.Context) { common.ErrorResp(c, err, 500, true) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return @@ -230,6 +230,10 @@ func toObjsResp(objs []model.Obj, parent string, encrypt bool) []ObjResp { for _, obj := range objs { thumb, _ := model.GetThumb(obj) mountDetails, _ := model.GetStorageDetails(obj) + hashInfo := obj.GetHash().Export() + if hashInfo == nil { + hashInfo = make(map[*utils.HashType]string) + } resp = append(resp, ObjResp{ Name: obj.GetName(), Size: obj.GetSize(), @@ -237,7 +241,7 @@ func toObjsResp(objs []model.Obj, parent string, encrypt bool) []ObjResp { Modified: obj.ModTime(), Created: obj.CreateTime(), HashInfoStr: obj.GetHash().String(), - HashInfo: obj.GetHash().Export(), + HashInfo: hashInfo, Sign: common.Sign(obj, parent, encrypt), Thumb: thumb, Type: utils.GetObjType(obj.GetName(), obj.IsDir()), @@ -291,7 +295,7 @@ func FsGet(c *gin.Context, req *FsGetReq, user *model.User) { common.ErrorResp(c, err, 500, true) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, reqPath, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return @@ -415,7 +419,7 @@ func FsOther(c *gin.Context) { common.ErrorResp(c, err, 500) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) if !common.CanAccess(user, meta, req.Path, req.Password) { common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) return diff --git a/server/handles/fsup.go b/server/handles/fsup.go index 0f46398cd..54cdb4fee 100644 --- a/server/handles/fsup.go +++ b/server/handles/fsup.go @@ -93,6 +93,12 @@ func FsStream(c *gin.Context) { if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" { h[utils.SHA256] = sha256 } + if sha1_128k := c.GetHeader("X-File-Sha1-128k"); sha1_128k != "" { + h[utils.SHA1_128K] = sha1_128k + } + if preHash := c.GetHeader("X-File-Pre-Hash"); preHash != "" { + h[utils.PRE_HASH] = preHash + } mimetype := c.GetHeader("Content-Type") if len(mimetype) == 0 { mimetype = utils.GetMimeType(name) @@ -190,6 +196,12 @@ func FsForm(c *gin.Context) { if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" { h[utils.SHA256] = sha256 } + if sha1_128k := c.GetHeader("X-File-Sha1-128k"); sha1_128k != "" { + h[utils.SHA1_128K] = sha1_128k + } + if preHash := c.GetHeader("X-File-Pre-Hash"); preHash != "" { + h[utils.PRE_HASH] = preHash + } mimetype := file.Header.Get("Content-Type") if len(mimetype) == 0 { mimetype = utils.GetMimeType(name) diff --git a/server/handles/torrent.go b/server/handles/torrent.go new file mode 100644 index 000000000..8b6ee1b6b --- /dev/null +++ b/server/handles/torrent.go @@ -0,0 +1,433 @@ +package handles + +import ( + "encoding/base64" + "fmt" + "io" + "strings" + + _189pc "github.com/OpenListTeam/OpenList/v4/drivers/189pc" + "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/op" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/OpenListTeam/OpenList/v4/pkg/torrent" + "github.com/OpenListTeam/OpenList/v4/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +// maxTorrentBase64Len is the max allowed Base64-encoded torrent size (~10MB decoded) +const maxTorrentBase64Len = 14 * 1024 * 1024 + +// maxTorrentGenFileSize is the max file size allowed for synchronous torrent generation (1GB) +const maxTorrentGenFileSize = 1 * 1024 * 1024 * 1024 + +// validateParsedTorrent checks that basic torrent invariants hold. +func validateParsedTorrent(t *torrent.Torrent) error { + if len(t.Info.Pieces)%20 != 0 { + return fmt.Errorf("torrent pieces 数据无效:长度必须为 20 的整数倍") + } + return nil +} + +// ParseTorrentReq 解析 torrent 文件请求 +type ParseTorrentReq struct { + // TorrentData Base64 编码的 torrent 文件内容 + TorrentData string `json:"torrent_data" binding:"required"` +} + +// ParseTorrentResp 解析 torrent 文件响应 +type ParseTorrentResp struct { + // Name 种子名称 + Name string `json:"name"` + // TotalSize 总大小 + TotalSize int64 `json:"total_size"` + // PieceLength 分片大小 + PieceLength int64 `json:"piece_length"` + // PieceCount 分片数量 + PieceCount int `json:"piece_count"` + // InfoHash info_hash(十六进制) + InfoHash string `json:"info_hash"` + // Files 文件列表(多文件模式) + Files []TorrentFileInfo `json:"files"` + // HasCAS 是否包含 CAS 扩展信息 + HasCAS bool `json:"has_cas"` + // CAS CAS 扩展信息 + CAS *CASInfoResp `json:"cas,omitempty"` +} + +// TorrentFileInfo torrent 中的文件信息 +type TorrentFileInfo struct { + Path string `json:"path"` + Size int64 `json:"size"` +} + +// CASInfoResp CAS 信息响应 +type CASInfoResp struct { + FileMD5 string `json:"file_md5"` + SliceMD5 string `json:"slice_md5"` + SliceSize int64 `json:"slice_size"` + Cloud string `json:"cloud"` +} + +// ParseTorrent 解析 torrent 文件,返回文件列表等信息 +func ParseTorrent(c *gin.Context) { + var req ParseTorrentReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + + // 限制 Base64 输入大小(最大 ~10MB decoded) + if len(req.TorrentData) > maxTorrentBase64Len { + common.ErrorResp(c, fmt.Errorf("torrent 数据过大(最大 10MB)"), 400) + return + } + + // Base64 解码 + torrentData, err := base64.StdEncoding.DecodeString(req.TorrentData) + if err != nil { + common.ErrorResp(c, fmt.Errorf("无效的 Base64 编码: %w", err), 400) + return + } + + // 解析 torrent + t, err := torrent.Decode(torrentData) + if err != nil { + common.ErrorResp(c, fmt.Errorf("解析 torrent 失败: %w", err), 400) + return + } + if err := validateParsedTorrent(t); err != nil { + common.ErrorResp(c, err, 400) + return + } + + resp := ParseTorrentResp{ + Name: t.Info.Name, + TotalSize: t.GetTotalSize(), + PieceLength: t.Info.PieceLength, + PieceCount: len(t.Info.Pieces) / 20, + InfoHash: t.GetInfoHashHex(), + HasCAS: t.HasCASInfo(), + } + + // 文件列表 + if len(t.Info.Files) > 0 { + resp.Files = make([]TorrentFileInfo, 0, len(t.Info.Files)) + for _, f := range t.Info.Files { + resp.Files = append(resp.Files, TorrentFileInfo{ + Path: strings.Join(f.Path, "/"), + Size: f.Length, + }) + } + } else { + // 单文件模式 + resp.Files = []TorrentFileInfo{ + {Path: t.Info.Name, Size: t.Info.Length}, + } + } + + // CAS 信息 + if t.HasCASInfo() { + resp.CAS = &CASInfoResp{ + FileMD5: t.CAS.FileMD5, + SliceMD5: t.CAS.SliceMD5, + SliceSize: t.CAS.SliceSize, + Cloud: t.CAS.Cloud, + } + } + + common.SuccessResp(c, resp) +} + +// TorrentRapidUploadReq 从 torrent 秒传请求 +type TorrentRapidUploadReq struct { + // TorrentData Base64 编码的 torrent 文件内容 + TorrentData string `json:"torrent_data" binding:"required"` + // Path 目标路径 + Path string `json:"path" binding:"required"` +} + +// TorrentRapidUpload 从 torrent 文件中提取 CAS 信息尝试秒传到天翼云 +func TorrentRapidUpload(c *gin.Context) { + user := c.Request.Context().Value(conf.UserKey).(*model.User) + + var req TorrentRapidUploadReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + + // 检查权限 + meta, err := op.GetNearestMeta(reqPath) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + if !common.CanWrite(user, meta, reqPath) { + common.ErrorResp(c, errs.PermissionDenied, 403) + return + } + + // Base64 解码 + torrentData, err := base64.StdEncoding.DecodeString(req.TorrentData) + if err != nil { + common.ErrorResp(c, fmt.Errorf("无效的 Base64 编码: %w", err), 400) + return + } + + // 解析 torrent + t, err := torrent.Decode(torrentData) + if err != nil { + common.ErrorResp(c, fmt.Errorf("解析 torrent 失败: %w", err), 400) + return + } + + if !t.HasCASInfo() { + common.ErrorResp(c, fmt.Errorf("torrent 不包含 CAS 扩展信息,无法秒传"), 400) + return + } + + // 获取目标存储 + storage, dstDirActualPath, err := op.GetStorageAndActualPath(reqPath) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + + // 获取目标目录对象 + dstDir, err := op.Get(c.Request.Context(), storage, dstDirActualPath) + if err != nil { + common.ErrorResp(c, fmt.Errorf("获取目标目录失败: %w", err), 500) + return + } + if !dstDir.IsDir() { + common.ErrorResp(c, errs.NotFolder, 400) + return + } + + // 检查是否是天翼云 PC 驱动 + cloud189PC, ok := storage.(*_189pc.Cloud189PC) + if !ok { + common.ErrorResp(c, fmt.Errorf("目标存储不是天翼云PC驱动,不支持 CAS 秒传"), 400) + return + } + + // 尝试秒传 + obj, err := cloud189PC.RapidUploadFromTorrent(c.Request.Context(), dstDir, torrentData, true) + if err != nil { + common.ErrorResp(c, fmt.Errorf("秒传失败: %w", err), 400) + return + } + + common.SuccessResp(c, gin.H{ + "message": "秒传成功", + "file_name": obj.GetName(), + "file_size": obj.GetSize(), + }) +} + +// UploadTorrentAndParse 通过文件上传方式解析 torrent +func UploadTorrentAndParse(c *gin.Context) { + file, err := c.FormFile("torrent") + if err != nil { + common.ErrorResp(c, fmt.Errorf("获取上传文件失败: %w", err), 400) + return + } + + // 限制文件大小(最大 10MB) + if file.Size > 10*1024*1024 { + common.ErrorResp(c, fmt.Errorf("torrent 文件过大(最大 10MB)"), 400) + return + } + + f, err := file.Open() + if err != nil { + common.ErrorResp(c, fmt.Errorf("打开文件失败: %w", err), 500) + return + } + defer f.Close() + + torrentData, err := io.ReadAll(f) + if err != nil { + common.ErrorResp(c, fmt.Errorf("读取文件失败: %w", err), 500) + return + } + + // 解析 torrent + t, err := torrent.Decode(torrentData) + if err != nil { + common.ErrorResp(c, fmt.Errorf("解析 torrent 失败: %w", err), 400) + return + } + if err := validateParsedTorrent(t); err != nil { + common.ErrorResp(c, err, 400) + return + } + + resp := ParseTorrentResp{ + Name: t.Info.Name, + TotalSize: t.GetTotalSize(), + PieceLength: t.Info.PieceLength, + PieceCount: len(t.Info.Pieces) / 20, + InfoHash: t.GetInfoHashHex(), + HasCAS: t.HasCASInfo(), + } + + // 文件列表 + if len(t.Info.Files) > 0 { + resp.Files = make([]TorrentFileInfo, 0, len(t.Info.Files)) + for _, f := range t.Info.Files { + resp.Files = append(resp.Files, TorrentFileInfo{ + Path: strings.Join(f.Path, "/"), + Size: f.Length, + }) + } + } else { + resp.Files = []TorrentFileInfo{ + {Path: t.Info.Name, Size: t.Info.Length}, + } + } + + // CAS 信息 + if t.HasCASInfo() { + resp.CAS = &CASInfoResp{ + FileMD5: t.CAS.FileMD5, + SliceMD5: t.CAS.SliceMD5, + SliceSize: t.CAS.SliceSize, + Cloud: t.CAS.Cloud, + } + } + + // 同时返回 Base64 编码的 torrent 数据,方便后续使用 + common.SuccessResp(c, gin.H{ + "info": resp, + "torrent_data": base64.StdEncoding.EncodeToString(torrentData), + }) +} + +// GenerateTorrentReq 为指定路径的文件生成 torrent 请求 +type GenerateTorrentReq struct { + // Path 文件在 OpenList 中的路径 + Path string `json:"path" binding:"required"` + // WithCAS 是否注入 CAS 扩展信息(仅天翼云需要) + WithCAS bool `json:"with_cas"` +} + +// GenerateTorrentForPath 为指定路径的文件生成 torrent +// 这是一个通用接口,适用于所有驱动 +// 会获取文件内容计算哈希,然后生成 torrent +func GenerateTorrentForPath(c *gin.Context) { + user := c.Request.Context().Value(conf.UserKey).(*model.User) + + var req GenerateTorrentReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + + // 检查读取权限 + meta, err := op.GetNearestMeta(reqPath) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + if !common.CanRead(user, meta, reqPath) { + common.ErrorResp(c, errs.PermissionDenied, 403) + return + } + + // 获取存储和文件信息 + storage, actualPath, err := op.GetStorageAndActualPath(reqPath) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + + // with_cas 仅支持天翼云PC驱动 + if req.WithCAS { + if _, is189pc := storage.(*_189pc.Cloud189PC); !is189pc { + common.ErrorResp(c, fmt.Errorf("CAS 秒传扩展仅支持天翼云PC驱动"), 400) + return + } + } + + // 获取文件对象 + obj, err := op.Get(c.Request.Context(), storage, actualPath) + if err != nil { + common.ErrorResp(c, fmt.Errorf("获取文件失败: %w", err), 500) + return + } + if obj.IsDir() { + common.ErrorResp(c, fmt.Errorf("不支持为目录生成 torrent"), 400) + return + } + + // 限制可生成 torrent 的文件大小 + if obj.GetSize() > maxTorrentGenFileSize { + common.ErrorResp(c, fmt.Errorf("文件过大,无法生成 torrent(最大 1GB)"), 400) + return + } + + // 获取文件下载链接 + link, _, err := op.Link(c.Request.Context(), storage, actualPath, model.LinkArgs{}) + if err != nil { + common.ErrorResp(c, fmt.Errorf("获取文件链接失败: %w", err), 500) + return + } + defer link.Close() + + // 通过 RangeReader 获取文件内容并计算哈希生成 torrent + if link.RangeReader == nil { + common.ErrorResp(c, fmt.Errorf("该存储不支持流式读取,无法生成 torrent(请先下载文件到本地)"), 400) + return + } + + // 读取整个文件 + rc, err := link.RangeReader.RangeRead(c.Request.Context(), http_range.Range{Length: obj.GetSize()}) + if err != nil { + common.ErrorResp(c, fmt.Errorf("读取文件失败: %w", err), 500) + return + } + defer rc.Close() + + var torrentData []byte + if req.WithCAS { + torrentData, err = torrent.GenerateFromReaderWithCAS(rc, obj.GetName(), obj.GetSize(), torrent.DefaultPieceSize) + } else { + torrentData, err = torrent.GenerateFromReader(rc, obj.GetName(), obj.GetSize(), torrent.DefaultPieceSize) + } + if err != nil { + common.ErrorResp(c, fmt.Errorf("生成 torrent 失败: %w", err), 500) + return + } + + // 解析生成的 torrent 获取 info_hash + t, _ := torrent.Decode(torrentData) + var infoHash string + if t != nil { + infoHash = t.GetInfoHashHex() + } + + common.SuccessResp(c, gin.H{ + "torrent_data": base64.StdEncoding.EncodeToString(torrentData), + "info_hash": infoHash, + "file_name": obj.GetName() + ".torrent", + "size": len(torrentData), + "with_cas": req.WithCAS, + }) +} diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index 0fc243617..ca67e4a6d 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -24,7 +24,7 @@ func Auth(allowDisabledGuest bool) func(c *gin.Context) { c.Abort() return } - common.GinWithValue(c, conf.UserKey, admin) + common.GinAppendValues(c, conf.UserKey, admin) log.Debugf("use admin token: %+v", admin) c.Next() return @@ -41,7 +41,7 @@ func Auth(allowDisabledGuest bool) func(c *gin.Context) { c.Abort() return } - common.GinWithValue(c, conf.UserKey, guest) + common.GinAppendValues(c, conf.UserKey, guest) log.Debugf("use empty token: %+v", guest) c.Next() return @@ -69,7 +69,7 @@ func Auth(allowDisabledGuest bool) func(c *gin.Context) { c.Abort() return } - common.GinWithValue(c, conf.UserKey, user) + common.GinAppendValues(c, conf.UserKey, user) log.Debugf("use login token: %+v", user) c.Next() } @@ -84,7 +84,7 @@ func Authn(c *gin.Context) { c.Abort() return } - common.GinWithValue(c, conf.UserKey, admin) + common.GinAppendValues(c, conf.UserKey, admin) log.Debugf("use admin token: %+v", admin) c.Next() return @@ -96,7 +96,7 @@ func Authn(c *gin.Context) { c.Abort() return } - common.GinWithValue(c, conf.UserKey, guest) + common.GinAppendValues(c, conf.UserKey, guest) log.Debugf("use empty token: %+v", guest) c.Next() return @@ -124,7 +124,7 @@ func Authn(c *gin.Context) { c.Abort() return } - common.GinWithValue(c, conf.UserKey, user) + common.GinAppendValues(c, conf.UserKey, user) log.Debugf("use login token: %+v", user) c.Next() } diff --git a/server/middlewares/check.go b/server/middlewares/check.go index c7203a490..c5e07874c 100644 --- a/server/middlewares/check.go +++ b/server/middlewares/check.go @@ -29,7 +29,7 @@ func StoragesLoaded(c *gin.Context) { return } } - common.GinWithValue(c, + common.GinAppendValues(c, conf.ApiUrlKey, common.GetApiUrlFromRequest(c.Request), ) c.Next() diff --git a/server/middlewares/down.go b/server/middlewares/down.go index c1f81b54b..d71be00b3 100644 --- a/server/middlewares/down.go +++ b/server/middlewares/down.go @@ -17,7 +17,7 @@ import ( func PathParse(c *gin.Context) { rawPath := parsePath(c.Param("path")) - common.GinWithValue(c, conf.PathKey, rawPath) + common.GinAppendValues(c, conf.PathKey, rawPath) c.Next() } @@ -29,7 +29,7 @@ func Down(verifyFunc func(string, string) error) func(c *gin.Context) { common.ErrorPage(c, err, 500, true) return } - common.GinWithValue(c, conf.MetaKey, meta) + common.GinAppendValues(c, conf.MetaKey, meta) // verify sign if needSign(meta, rawPath) { s := c.Query("sign") diff --git a/server/middlewares/sharing.go b/server/middlewares/sharing.go index d7549202f..aa0cab0c6 100644 --- a/server/middlewares/sharing.go +++ b/server/middlewares/sharing.go @@ -8,11 +8,11 @@ import ( func SharingIdParse(c *gin.Context) { sid := c.Param("sid") - common.GinWithValue(c, conf.SharingIDKey, sid) + common.GinAppendValues(c, conf.SharingIDKey, sid) c.Next() } func EmptyPathParse(c *gin.Context) { - common.GinWithValue(c, conf.PathKey, "/") + common.GinAppendValues(c, conf.PathKey, "/") c.Next() } diff --git a/server/router.go b/server/router.go index 57d1166ae..03285581d 100644 --- a/server/router.go +++ b/server/router.go @@ -217,6 +217,11 @@ func _fs(g *gin.RouterGroup) { // g.POST("/add_transmission", handles.SetTransmission) g.POST("/add_offline_download", handles.AddOfflineDownload) g.POST("/archive/decompress", handles.FsArchiveDecompress) + // Torrent 相关接口 + g.POST("/torrent/parse", handles.ParseTorrent) + g.POST("/torrent/upload_parse", handles.UploadTorrentAndParse) + g.POST("/torrent/rapid_upload", handles.TorrentRapidUpload) + g.POST("/torrent/generate", handles.GenerateTorrentForPath) // Direct upload (client-side upload to storage) g.POST("/get_direct_upload_info", middlewares.FsUp, handles.FsGetDirectUploadInfo) } @@ -242,6 +247,10 @@ func Cors(r *gin.Engine) { config.AllowHeaders = conf.Conf.Cors.AllowHeaders config.AllowMethods = conf.Conf.Cors.AllowMethods r.Use(cors.New(config)) + r.Use(func(c *gin.Context) { + c.Header("Cross-Origin-Opener-Policy", "same-origin") + c.Header("Cross-Origin-Embedder-Policy", "credentialless") + }) } func InitS3(e *gin.Engine) { diff --git a/server/static/static.go b/server/static/static.go index 29f97ff74..81695c9fb 100644 --- a/server/static/static.go +++ b/server/static/static.go @@ -8,10 +8,13 @@ import ( "io/fs" "net/http" "os" + "path/filepath" "strings" + "sync" "github.com/OpenListTeam/OpenList/v4/drivers/base" "github.com/OpenListTeam/OpenList/v4/internal/conf" + "github.com/OpenListTeam/OpenList/v4/internal/frontend" "github.com/OpenListTeam/OpenList/v4/internal/setting" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/public" @@ -32,21 +35,41 @@ type Manifest struct { Icons []ManifestIcon `json:"icons"` } -var static fs.FS +// reloadableFS wraps fs.FS with thread-safe swapping. +// This allows gin StaticFS routes (which capture the fs.FS at registration time) +// to serve updated files after a watcher-triggered reload. +type reloadableFS struct { + mu sync.RWMutex + current fs.FS +} + +func (r *reloadableFS) Open(name string) (fs.File, error) { + r.mu.RLock() + current := r.current + r.mu.RUnlock() + return current.Open(name) +} + +func (r *reloadableFS) swap(f fs.FS) { + r.mu.Lock() + r.current = f + r.mu.Unlock() +} + +var staticFS = &reloadableFS{} func initStatic() { - utils.Log.Debug("Initializing static file system...") - if conf.Conf.DistDir == "" { - dist, err := fs.Sub(public.Public, "dist") - if err != nil { - utils.Log.Fatalf("failed to read dist dir: %v", err) - } - static = dist - utils.Log.Debug("Using embedded dist directory") + if conf.Conf.DistDir != "" { + staticFS.swap(os.DirFS(conf.Conf.DistDir)) + utils.Log.Infof("Using custom dist directory: %s", conf.Conf.DistDir) return } - static = os.DirFS(conf.Conf.DistDir) - utils.Log.Infof("Using custom dist directory: %s", conf.Conf.DistDir) + dist, err := fs.Sub(public.Public, "dist") + if err != nil { + utils.Log.Fatalf("failed to read embedded dist dir: %v", err) + } + staticFS.swap(dist) + utils.Log.Infof("Using embedded dist directory") } func replaceStrings(content string, replacements map[string]string) string { @@ -74,7 +97,7 @@ func initIndex(siteConfig SiteConfig) { utils.Log.Info("Successfully fetched index.html from CDN") } else { utils.Log.Debug("Reading index.html from static files system...") - indexFile, err := static.Open("index.html") + indexFile, err := staticFS.Open("index.html") if err != nil { if errors.Is(err, fs.ErrNotExist) { utils.Log.Fatalf("index.html not exist, you may forget to put dist of frontend to public/dist") @@ -131,13 +154,23 @@ func UpdateIndex() { utils.Log.Debug("Index.html update completed") } +// ReloadStatic reloads the static files from disk (called by the watcher after an update) +func ReloadStatic() { + utils.Log.Info("[static] reloading static files after frontend update...") + distPath := filepath.Join(frontend.GetDistPath(), "dist") + staticFS.swap(os.DirFS(distPath)) + utils.Log.Infof("Switched to dynamically fetched dist: %s", distPath) + siteConfig := getSiteConfig() + initIndex(siteConfig) +} + func ManifestJSON(c *gin.Context) { // Get site configuration to ensure consistent base path handling siteConfig := getSiteConfig() - + // Get site title from settings siteTitle := setting.GetStr(conf.SiteTitle) - + // Get logo from settings, use the first line (light theme logo) logoSetting := setting.GetStr(conf.Logo) logoUrl := strings.Split(logoSetting, "\n")[0] @@ -167,7 +200,7 @@ func ManifestJSON(c *gin.Context) { c.Header("Content-Type", "application/json") c.Header("Cache-Control", "public, max-age=3600") // cache for 1 hour - + if err := json.NewEncoder(c.Writer).Encode(manifest); err != nil { utils.Log.Errorf("Failed to encode manifest.json: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate manifest"}) @@ -180,8 +213,12 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { siteConfig := getSiteConfig() initStatic() initIndex(siteConfig) + + // Start the frontend watcher for periodic updates + frontend.StartWatcher(ReloadStatic) + folders := []string{"assets", "images", "streamer", "static"} - + if conf.Conf.Cdn == "" { utils.Log.Debug("Setting up static file serving...") r.Use(func(c *gin.Context) { @@ -192,7 +229,7 @@ func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { } }) for _, folder := range folders { - sub, err := fs.Sub(static, folder) + sub, err := fs.Sub(staticFS, folder) if err != nil { utils.Log.Fatalf("can't find folder: %s", folder) } diff --git a/server/webdav.go b/server/webdav.go index a949068f0..74523b0b3 100644 --- a/server/webdav.go +++ b/server/webdav.go @@ -54,7 +54,7 @@ func WebDAVAuth(c *gin.Context) { count, cok := model.LoginCache.Get(ip) if cok && count >= model.DefaultMaxAuthRetries { if c.Request.Method == "OPTIONS" { - common.GinWithValue(c, conf.UserKey, guest) + common.GinAppendValues(c, conf.UserKey, guest) c.Next() return } @@ -78,13 +78,13 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } - common.GinWithValue(c, conf.UserKey, admin) + common.GinAppendValues(c, conf.UserKey, admin) c.Next() return } } if c.Request.Method == "OPTIONS" { - common.GinWithValue(c, conf.UserKey, guest) + common.GinAppendValues(c, conf.UserKey, guest) c.Next() return } @@ -96,7 +96,7 @@ func WebDAVAuth(c *gin.Context) { user, ok := tryLogin(username, password) if !ok { if c.Request.Method == "OPTIONS" { - common.GinWithValue(c, conf.UserKey, guest) + common.GinAppendValues(c, conf.UserKey, guest) c.Next() return } @@ -109,7 +109,7 @@ func WebDAVAuth(c *gin.Context) { model.LoginCache.Del(ip) if user.Disabled || !user.CanWebdavRead() { if c.Request.Method == "OPTIONS" { - common.GinWithValue(c, conf.UserKey, guest) + common.GinAppendValues(c, conf.UserKey, guest) c.Next() return } @@ -142,11 +142,11 @@ func WebDAVAuth(c *gin.Context) { c.Abort() return } - common.GinWithValue(c, conf.UserKey, user) + common.GinAppendValues(c, conf.UserKey, user) if user.IsGuest() { - common.GinWithValue(c, conf.MetaPassKey, password) + common.GinAppendValues(c, conf.MetaPassKey, password) } else { - common.GinWithValue(c, conf.MetaPassKey, "") + common.GinAppendValues(c, conf.MetaPassKey, "") } c.Next() }