cuslines/cuda_python/cu_propagate_seeds.py (35 additions, 0 deletions)
@@ -256,5 +256,40 @@ def _yield_slines():

        return _yield_slines()

    def gen_bin_indices(self, bin_starts, bin_len):
        bin_edges = np.append(bin_starts, bin_starts[-1] + bin_len)
        bin_indices = {k: [] for k in range(len(bin_starts))}

        for ii in range(self.ngpus):
            scaled_lens = self.sline_lens[ii] * self.gpu_tracker.step_size
            assignments = np.digitize(scaled_lens, bin_edges) - 1

            for k in range(len(bin_starts)):
                jj_indices = np.where(assignments == k)[0]
                if jj_indices.size > 0:
                    bin_indices[k].append((ii, jj_indices))

        return bin_indices

    def as_array_sequence_group(self, bin_indices, bin_start):
        relevant_blocks = bin_indices[bin_start]
Comment on lines +259 to +276 (Copilot AI, Jan 28, 2026):
gen_bin_indices builds bin_indices as a dict keyed by integer indices 0..len(bin_starts)-1, but as_array_sequence_group indexes bin_indices using the bin_start argument directly; when callers pass the physical bin start values (e.g., 20, 30, ...), this will not match the integer keys and will lead to KeyError or empty groups. The keys should be made consistent (either indices everywhere or the actual bin_start values), for example by keying bin_indices by the bin start values inside gen_bin_indices and keeping as_array_sequence_group's interface as-is.

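To make the mismatch concrete, here is a minimal standalone repro (hypothetical values chosen to match the defaults in generate_trx_grouped_by_len below; not part of this diff):

    import numpy as np

    # Defaults: min_len=20, max_len=250, bin_len=10 -> 24 bins
    bin_starts = np.arange(20, 250 + 10, 10)               # [20, 30, ..., 250]
    bin_indices = {k: [] for k in range(len(bin_starts))}  # keys 0..23

    # Looking up by physical bin start, as as_array_sequence_group does:
    print(20 in bin_indices)  # True, but key 20 is the bin starting at 220
    print(30 in bin_indices)  # False -> bin_indices[30] raises KeyError

So a bin_start of 20 silently reads the wrong bin, and every other bin start raises KeyError.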
        def _yield_slines():
            for ii, jj_array in relevant_blocks:
                gpu_slines = self.slines[ii]
                gpu_lens = self.sline_lens[ii]
                for jj in jj_array:
                    npts = gpu_lens[jj]
                    yield np.asarray(gpu_slines[jj], dtype=REAL_DTYPE)[:npts]

        def _get_buffer_size():
            total_pts = 0
            for ii, jj_array in relevant_blocks:
                total_pts += np.sum(self.sline_lens[ii][jj_array])

            return math.ceil((total_pts * 3 * REAL_SIZE) / MEGABYTE)

        return ArraySequence(_yield_slines(), _get_buffer_size())

    def as_array_sequence(self):
        return ArraySequence(self.as_generator(), self.get_buffer_size())
cuslines/cuda_python/cu_tractography.py (71 additions, 0 deletions)
@@ -313,3 +313,74 @@ def generate_trx(self, seeds, ref_img):
        trx_file.resize()

        return trx_file

    def generate_trx_grouped_by_len(self, seeds, ref_img, min_len=20, max_len=250, bin_len=10):
        global_chunk_sz, nchunks = self._divide_chunks(seeds)

        # Will resize by a factor of 2 if these are exceeded
        sl_per_seed_guess = 3
        n_sls_guess = sl_per_seed_guess * seeds.shape[0]

        bin_starts = np.arange(min_len, max_len + bin_len, bin_len)
        trx_files = {}
        offsets_idxs = {}
        sls_data_idxs = {}
        for bin_start in bin_starts:
            max_steps = (bin_start + bin_len) / self.step_size
            trx_files[bin_start] = TrxFile(
                reference=ref_img,
                nb_streamlines=n_sls_guess,
                nb_vertices=n_sls_guess * max_steps,
            )
Comment on lines +329 to +335 (Copilot AI, Jan 28, 2026):

max_steps = (bin_start + bin_len) / self.step_size will generally be a float, so nb_vertices=n_sls_guess * max_steps passes a non-integer value where TrxFile expects a vertex count (an integer) and may use it to size numpy arrays/memmaps. To avoid type errors or under-allocation, consider computing an integer upper bound for the number of steps per streamline (e.g., using math.ceil on the division) and passing that integer to nb_vertices.

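A sketch of the integer fix the comment suggests (not the PR's code; it reuses the surrounding names bin_starts, bin_len, n_sls_guess, and ref_img from the diff):

    import math

    for bin_start in bin_starts:
        # Integer upper bound on points per streamline in this bin
        max_steps = math.ceil((bin_start + bin_len) / self.step_size)
        trx_files[bin_start] = TrxFile(
            reference=ref_img,
            nb_streamlines=n_sls_guess,
            nb_vertices=n_sls_guess * max_steps,  # now an int, as a vertex count should be
        )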
            trx_files[bin_start].streamlines._offsets = \
                trx_files[bin_start].streamlines._offsets.astype(np.uint64)
            offsets_idxs[bin_start] = 0
            sls_data_idxs[bin_start] = 0

        with tqdm(total=seeds.shape[0]) as pbar:
            for idx in range(int(nchunks)):
                self.seed_propagator.propagate(
                    seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz]
                )
                bin_indices = self.seed_propagator.gen_bin_indices(bin_starts, bin_len)
                for bin_start in bin_starts:
                    tractogram = Tractogram(
                        self.seed_propagator.as_array_sequence_group(bin_indices, bin_start),
Comment on lines +347 to +349 (Copilot AI, Jan 28, 2026):

Here bin_indices = self.seed_propagator.gen_bin_indices(bin_starts, bin_len) returns a dict keyed by integer bin indices (0..len(bin_starts)-1), but the subsequent loop uses the physical bin_start values (e.g., 20, 30, ...) as keys when calling as_array_sequence_group, which internally does bin_indices[bin_start]. This mismatch in key types means lookups into bin_indices will fail or return no streamlines; the binning logic needs to use consistent keys (either indices or bin start values) across gen_bin_indices, generate_trx_grouped_by_len, and as_array_sequence_group.

Suggested change:

    -                for bin_start in bin_starts:
    -                    tractogram = Tractogram(
    -                        self.seed_propagator.as_array_sequence_group(bin_indices, bin_start),
    +                for bin_idx, bin_start in enumerate(bin_starts):
    +                    tractogram = Tractogram(
    +                        self.seed_propagator.as_array_sequence_group(bin_indices, bin_idx),

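The other option the comment body mentions is to fix the producer instead: key the dict by the physical bin start values inside gen_bin_indices, which leaves as_array_sequence_group and the loop above unchanged. A sketch of that variant (not part of this diff):

    def gen_bin_indices(self, bin_starts, bin_len):
        bin_edges = np.append(bin_starts, bin_starts[-1] + bin_len)
        bin_indices = {start: [] for start in bin_starts}  # key by value, not index

        for ii in range(self.ngpus):
            scaled_lens = self.sline_lens[ii] * self.gpu_tracker.step_size
            assignments = np.digitize(scaled_lens, bin_edges) - 1

            for k, start in enumerate(bin_starts):
                jj_indices = np.where(assignments == k)[0]
                if jj_indices.size > 0:
                    bin_indices[start].append((ii, jj_indices))

        return bin_indices

Either variant works; the point is that gen_bin_indices, generate_trx_grouped_by_len, and as_array_sequence_group must agree on the key type.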
                        affine_to_rasmm=ref_img.affine,
                    )
                    tractogram.to_world()
                    sls = tractogram.streamlines

                    new_offsets_idx = offsets_idxs[bin_start] + len(sls._offsets)
                    new_sls_data_idx = sls_data_idxs[bin_start] + len(sls._data)

                    if (
                        new_offsets_idx > trx_files[bin_start].header["NB_STREAMLINES"]
                        or new_sls_data_idx > trx_files[bin_start].header["NB_VERTICES"]
                    ):
                        logger.info("TRX resizing...")
                        trx_files[bin_start].resize(
                            nb_streamlines=new_offsets_idx * 2,
                            nb_vertices=new_sls_data_idx * 2,
                        )

                    # TRX uses memmaps here
                    trx_files[bin_start].streamlines._data[sls_data_idxs[bin_start]:new_sls_data_idx] = sls._data
                    trx_files[bin_start].streamlines._offsets[offsets_idxs[bin_start]:new_offsets_idx] = (
                        sls_data_idxs[bin_start] + sls._offsets
                    )
                    trx_files[bin_start].streamlines._lengths[offsets_idxs[bin_start]:new_offsets_idx] = (
                        sls._lengths
                    )

                    offsets_idxs[bin_start] = new_offsets_idx
                    sls_data_idxs[bin_start] = new_sls_data_idx
                pbar.update(
                    seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz].shape[0]
                )

        for bin_start in bin_starts:
            trx_files[bin_start].resize()

        return trx_files
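Assuming the key mismatch flagged above is resolved, a hypothetical usage sketch (tracker, seeds, and ref_img are assumed names, not defined in this diff): the method returns one TrxFile per length bin, keyed by the bin's start length.

    trx_by_len = tracker.generate_trx_grouped_by_len(
        seeds, ref_img, min_len=20, max_len=250, bin_len=10
    )
    for bin_start, trx in trx_by_len.items():
        # Each TrxFile holds only streamlines whose scaled length falls in
        # [bin_start, bin_start + bin_len); units depend on step_size.
        print(f"{bin_start}-{bin_start + 10}: {len(trx.streamlines)} streamlines")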