diff --git a/cuslines/cuda_python/cu_propagate_seeds.py b/cuslines/cuda_python/cu_propagate_seeds.py index c1d8a8b..cbe4fb1 100644 --- a/cuslines/cuda_python/cu_propagate_seeds.py +++ b/cuslines/cuda_python/cu_propagate_seeds.py @@ -256,5 +256,40 @@ def _yield_slines(): return _yield_slines() + def gen_bin_indices(self, bin_starts, bin_len): + bin_edges = np.append(bin_starts, bin_starts[-1] + bin_len) + bin_indices = {k: [] for k in range(len(bin_starts))} + + for ii in range(self.ngpus): + scaled_lens = self.sline_lens[ii] * self.gpu_tracker.step_size + assignments = np.digitize(scaled_lens, bin_edges) - 1 + + for k in range(len(bin_starts)): + jj_indices = np.where(assignments == k)[0] + if jj_indices.size > 0: + bin_indices[k].append((ii, jj_indices)) + + return bin_indices + + def as_array_sequence_group(self, bin_indices, bin_start): + relevant_blocks = bin_indices[bin_start] + + def _yield_slines(): + for ii, jj_array in relevant_blocks: + gpu_slines = self.slines[ii] + gpu_lens = self.sline_lens[ii] + for jj in jj_array: + npts = gpu_lens[jj] + yield np.asarray(gpu_slines[jj], dtype=REAL_DTYPE)[:npts] + + def _get_buffer_size(): + total_pts = 0 + for ii, jj_array in relevant_blocks: + total_pts += np.sum(self.sline_lens[ii][jj_array]) + + return math.ceil((total_pts * 3 * REAL_SIZE) / MEGABYTE) + + return ArraySequence(_yield_slines(), _get_buffer_size()) + def as_array_sequence(self): return ArraySequence(self.as_generator(), self.get_buffer_size()) diff --git a/cuslines/cuda_python/cu_tractography.py b/cuslines/cuda_python/cu_tractography.py index 7064cc9..466a326 100644 --- a/cuslines/cuda_python/cu_tractography.py +++ b/cuslines/cuda_python/cu_tractography.py @@ -313,3 +313,74 @@ def generate_trx(self, seeds, ref_img): trx_file.resize() return trx_file + + def generate_trx_grouped_by_len(self, seeds, ref_img, min_len=20, max_len=250, bin_len=10): + global_chunk_sz, nchunks = self._divide_chunks(seeds) + + # Will resize by a factor of 2 if these are exceeded + sl_per_seed_guess = 3 + n_sls_guess = sl_per_seed_guess * seeds.shape[0] + + + bin_starts = np.arange(min_len, max_len + bin_len, bin_len) + trx_files = {} + offsets_idxs = {} + sls_data_idxs = {} + for bin_start in bin_starts: + max_steps = (bin_start + bin_len) / self.step_size + trx_files[bin_start] = TrxFile( + reference=ref_img, + nb_streamlines=n_sls_guess, + nb_vertices=n_sls_guess * max_steps, + ) + trx_files[bin_start].streamlines._offsets = \ + trx_files[bin_start].streamlines._offsets.astype(np.uint64) + offsets_idxs[bin_start] = 0 + sls_data_idxs[bin_start] = 0 + + with tqdm(total=seeds.shape[0]) as pbar: + for idx in range(int(nchunks)): + self.seed_propagator.propagate( + seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz] + ) + bin_indices = self.seed_propagator.gen_bin_indices(bin_starts, bin_len) + for bin_start in bin_starts: + tractogram = Tractogram( + self.seed_propagator.as_array_sequence_group(bin_indices, bin_start), + affine_to_rasmm=ref_img.affine, + ) + tractogram.to_world() + sls = tractogram.streamlines + + new_offsets_idx = offsets_idxs[bin_start] + len(sls._offsets) + new_sls_data_idx = sls_data_idxs[bin_start] + len(sls._data) + + if ( + new_offsets_idx > trx_files[bin_start].header["NB_STREAMLINES"] + or new_sls_data_idx > trx_files[bin_start].header["NB_VERTICES"] + ): + logger.info("TRX resizing...") + trx_files[bin_start].resize( + nb_streamlines=new_offsets_idx * 2, + nb_vertices=new_sls_data_idx * 2, + ) + + # TRX uses memmaps here + trx_files[bin_start].streamlines._data[sls_data_idxs[bin_start]:new_sls_data_idx] = sls._data + trx_files[bin_start].streamlines._offsets[offsets_idxs[bin_start]:new_offsets_idx] = ( + sls_data_idxs[bin_start] + sls._offsets + ) + trx_files[bin_start].streamlines._lengths[offsets_idxs[bin_start]:new_offsets_idx] = ( + sls._lengths + ) + + offsets_idxs[bin_start] = new_offsets_idx + sls_data_idxs[bin_start] = new_sls_data_idx + pbar.update( + seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz].shape[0] + ) + + for bin_start in bin_starts: + trx_files[bin_start].resize() + + return trx_files