.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/brain_tractograms/tract_io.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_brain_tractograms_tract_io.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_brain_tractograms_tract_io.py:


Input-Output with brain tractograms
==========================================

.. GENERATED FROM PYTHON SOURCE LINES 6-319

.. code-block:: Python



    import argparse
    import os.path
    import sys
    import pdb
    import numpy as np
    from dipy.tracking.streamline import set_number_of_points

    import vtk
    from vtk.util import numpy_support as ns

    try:
        # Python 2
        from itertools import izip
    except ImportError:
        # Python 3
        izip = zip
    try:
        # Python 2
        xrange
    except NameError:
        # Python 3, xrange is now named range
        xrange = range


    def read_vtk(filename):
        """Load a VTK polydata file and split it into tracts.

        Names ending in ``xml`` or ``vtp`` are read with the XML polydata
        reader; every other name goes through the legacy polydata reader.

        Returns:
            The ``(tracts, data)`` pair produced by ``vtkPolyData_to_tracts``.
        """
        use_xml_reader = filename.endswith(("xml", "vtp"))
        if use_xml_reader:
            reader = vtk.vtkXMLPolyDataReader()
        else:
            reader = vtk.vtkPolyDataReader()

        reader.SetFileName(filename)
        reader.Update()
        return vtkPolyData_to_tracts(reader.GetOutput())


    def vtkPolyData_to_tracts(polydata):
        """Convert a ``vtkPolyData`` of polylines into tracts and point data.

        Collects the line connectivity, the point coordinates and every numeric
        point-data array of ``polydata`` into a dictionary and hands it to
        ``vtkPolyData_dictionary_to_tracts_and_data``.

        Returns:
            (tracts, data): one coordinate array per line, and a mapping from
            each point-data array name to a list with one array per line.
        """
        point_data = polydata.GetPointData()

        result = {
            "lines": ns.vtk_to_numpy(polydata.GetLines().GetData()),
            "points": ns.vtk_to_numpy(polydata.GetPoints().GetData()),
            "numberOfLines": polydata.GetNumberOfLines(),
        }

        # Names of the active attributes are recorded as plain strings; they
        # are not arrays, so the splitting step ignores them.
        data = {}
        if point_data.GetScalars():
            data["ActiveScalars"] = point_data.GetScalars().GetName()
            result["Scalars"] = point_data.GetScalars()
        if point_data.GetVectors():
            data["ActiveVectors"] = point_data.GetVectors().GetName()
        if point_data.GetTensors():
            data["ActiveTensors"] = point_data.GetTensors().GetName()

        for i in range(point_data.GetNumberOfArrays()):
            array = point_data.GetArray(i)
            # GetArray() returns None for arrays that are not vtkDataArray
            # (e.g. vtkStringArray); those cannot be converted to numpy.
            if array is None:
                continue
            np_array = ns.vtk_to_numpy(array)
            if np_array.ndim == 1:
                # Store 1-component arrays as column vectors (n_points, 1).
                np_array = np_array.reshape(len(np_array), 1)
            data[point_data.GetArrayName(i)] = np_array

        result["pointData"] = data

        return vtkPolyData_dictionary_to_tracts_and_data(result)


    def vtkPolyData_dictionary_to_tracts_and_data(dictionary):
        """Split a flattened polyline description into per-tract arrays.

        ``dictionary`` must provide:

        * ``"lines"`` -- legacy cell connectivity: for every line, the number
          of point ids ``n`` followed by those ``n`` ids;
        * ``"points"`` -- an array of point rows indexed by those ids;
        * ``"numberOfLines"`` -- how many lines to read from ``"lines"``.

        An optional ``"pointData"`` mapping is also read. Every value that is
        a ``numpy.ndarray`` is indexed by point id and split per tract, like
        the coordinates. Other values (e.g. attribute names) are ignored.

        Returns:
            (tracts, tract_data): ``tracts`` is a list with one array of points
            per line. ``tract_data`` maps each point-data name to a list with
            one array per line.

        Raises:
            ValueError: if any of the required keys is missing.
        """
        required_keys = {"lines", "points", "numberOfLines"}
        if not required_keys.issubset(dictionary.keys()):
            raise ValueError(
                "Dictionary must have the keys lines, points and numberOfLines: "
                + repr(dictionary.keys())
            )

        # ravel() (unlike squeeze()) always yields a 1-D array, even when the
        # connectivity holds a single value such as [0] or has shape (N, 1).
        lines = np.asarray(dictionary["lines"]).ravel()
        points = dictionary["points"]
        number_of_tracts = dictionary["numberOfLines"]

        tracts = []
        original_lines = []  # point ids of each line, reused for point data
        offset = 0
        for _ in range(number_of_tracts):
            count = int(lines[offset])
            # Copy so each id array does not keep a view on the whole buffer.
            ids = np.array(lines[offset + 1 : offset + 1 + count], copy=True)
            tracts.append(points[ids])
            original_lines.append(ids)
            offset += count + 1

        tract_data = {}
        if "pointData" in dictionary:
            for name, values in dictionary["pointData"].items():
                if isinstance(values, np.ndarray):
                    tract_data[name] = [values[ids] for ids in original_lines]

        return tracts, tract_data


    def save_vtk(filename, tracts, lines_indices=None, scalars=None):
        """Write tracts to a VTK polydata file as polylines.

        Args:
            filename: output path. Names ending in ``.xml`` or ``.vtp`` use the
                binary XML writer; any other name uses the binary legacy writer.
            tracts: non-empty sequence of point arrays, one per tract (tracts
                may differ in length). The stacked points are handed to
                ``vtkPoints``, so each row is expected to hold 3 coordinates.
            lines_indices: optional sequence giving, for every tract, the ids
                of its points in the stacked point array. Defaults to
                consecutive ids, i.e. each tract uses its own points in order.
            scalars: accepted for backward compatibility but ignored here; use
                ``save_vtk_labels`` to attach per-point scalars.
        """
        lengths = [len(tract) for tract in tracts]
        if lines_indices is None:
            line_starts = np.r_[0, np.cumsum(lengths)]
            lines_indices = [
                np.arange(length) + start
                for length, start in zip(lengths, line_starts)
            ]

        # Legacy cell layout: for each line, its point count followed by its
        # point ids. numpy_to_vtkIdTypeArray requires the exact vtkIdType width,
        # which is not the default integer type on every platform.
        ids = np.hstack(
            [np.r_[length, indices] for length, indices in zip(lengths, lines_indices)]
        ).astype(ns.ID_TYPE_CODE)
        vtk_ids = ns.numpy_to_vtkIdTypeArray(ids, deep=True)

        cell_array = vtk.vtkCellArray()
        cell_array.SetCells(len(tracts), vtk_ids)

        points = np.vstack(tracts).astype(
            ns.get_vtk_to_numpy_typemap()[vtk.VTK_DOUBLE]
        )
        vtk_points = vtk.vtkPoints()
        vtk_points.SetData(ns.numpy_to_vtk(points, deep=True))

        poly_data = vtk.vtkPolyData()
        poly_data.SetPoints(vtk_points)
        poly_data.SetLines(cell_array)
        poly_data.BuildCells()

        if filename.endswith((".xml", ".vtp")):
            writer = vtk.vtkXMLPolyDataWriter()
            writer.SetDataModeToBinary()
        else:
            writer = vtk.vtkPolyDataWriter()
            writer.SetFileTypeToBinary()

        writer.SetFileName(filename)
        # VTK >= 6 replaced SetInput() with SetInputData(); probe the writer
        # itself rather than relying on a version constant being exposed.
        if hasattr(writer, "SetInputData"):
            writer.SetInputData(poly_data)
        else:
            writer.SetInput(poly_data)
        writer.Write()


    def save_vtk_labels(filename, tracts, scalars, lines_indices=None):
        """Write tracts as polylines together with per-point scalars.

        Args:
            filename: output path. Names ending in ``.xml`` or ``.vtp`` use the
                binary XML writer; any other name uses the binary legacy writer.
            tracts: non-empty sequence of point arrays, one per tract. The
                stacked points are handed to ``vtkPoints``, so each row is
                expected to hold 3 coordinates.
            scalars: 1-D numpy array set as the point-data scalars; it should
                hold one value per stacked point, in the same order.
            lines_indices: optional per-tract point ids in the stacked point
                array. Defaults to consecutive ids.
        """
        lengths = [len(tract) for tract in tracts]
        if lines_indices is None:
            line_starts = np.r_[0, np.cumsum(lengths)]
            lines_indices = [
                np.arange(length) + start
                for length, start in zip(lengths, line_starts)
            ]

        # Legacy cell layout (count, ids...) cast to the vtkIdType width that
        # numpy_to_vtkIdTypeArray insists on.
        ids = np.hstack(
            [np.r_[length, indices] for length, indices in zip(lengths, lines_indices)]
        ).astype(ns.ID_TYPE_CODE)
        vtk_ids = ns.numpy_to_vtkIdTypeArray(ids, deep=True)

        cell_array = vtk.vtkCellArray()
        cell_array.SetCells(len(tracts), vtk_ids)

        points = np.vstack(tracts).astype(
            ns.get_vtk_to_numpy_typemap()[vtk.VTK_DOUBLE]
        )
        vtk_points = vtk.vtkPoints()
        vtk_points.SetData(ns.numpy_to_vtk(points, deep=True))

        poly_data = vtk.vtkPolyData()
        poly_data.SetPoints(vtk_points)
        poly_data.SetLines(cell_array)
        # Deep copy so the VTK array owns its data independently of the
        # caller's (possibly temporary or non-contiguous) numpy array.
        poly_data.GetPointData().SetScalars(ns.numpy_to_vtk(scalars, deep=True))
        poly_data.BuildCells()

        if filename.endswith((".xml", ".vtp")):
            writer = vtk.vtkXMLPolyDataWriter()
            writer.SetDataModeToBinary()
        else:
            writer = vtk.vtkPolyDataWriter()
            writer.SetFileTypeToBinary()

        writer.SetFileName(filename)
        # VTK >= 6 replaced SetInput() with SetInputData().
        if hasattr(writer, "SetInputData"):
            writer.SetInputData(poly_data)
        else:
            writer.SetInput(poly_data)
        writer.Write()


    def streamlines_resample(streamlines, perc=None, npoints=None):
        """Resample every streamline with dipy's ``set_number_of_points``.

        Args:
            streamlines: iterable of streamlines.
            perc: if given, each streamline is resampled to ``perc`` percent of
                its current number of points (truncated to an integer, and
                never fewer than 2 points).
            npoints: number of points used for every streamline when ``perc``
                is None.

        Returns:
            list of resampled streamlines, in input order.
        """
        if perc is not None:
            # dipy cannot resample to fewer than 2 points, which a small
            # percentage of a short streamline would otherwise request.
            return [
                set_number_of_points(s, max(2, int(len(s) * perc / 100.0)))
                for s in streamlines
            ]
        return [set_number_of_points(s, npoints) for s in streamlines]


    def check_ext(value):
        """argparse ``type`` callback accepting only .vtk, .xml or .vtp paths.

        Returns ``value`` unchanged when its extension is supported and raises
        ``argparse.ArgumentTypeError`` otherwise.
        """
        extension = os.path.splitext(value)[1]
        if extension not in (".vtk", ".xml", ".vtp"):
            raise argparse.ArgumentTypeError(
                "Invalid file extension (file format supported: vtk,xml,vtp): %r" % value
            )
        return value


    def check_resample(value):
        """argparse ``type`` callback validating a resampling percentage.

        Accepts any string that parses as a float in the inclusive range
        [0, 100] and returns it unchanged (the caller converts it to float).

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a number or lies
                outside [0, 100].
        """
        try:
            t = float(value)
        except ValueError:
            raise argparse.ArgumentTypeError("Invalid resampling value: %r" % value)
        # Inclusive bounds, matching the "between 0 and 100" message; 100%
        # means keeping the original number of points. NaN fails this test.
        if not 0 <= t <= 100:
            raise argparse.ArgumentTypeError(
                "Invalid resampling (must be between 0 and 100): %r" % value
            )
        return value


    def setup():
        """Parse the command line.

        Returns:
            (input_tractogram, percentage): the validated input path and the
            resampling percentage converted to ``float``.
        """
        parser = argparse.ArgumentParser()
        positional_args = (
            ("Input_Tractogram", "Name of the input tractography file", check_ext),
            ("Percentage", "Resampling percentage value", check_resample),
        )
        for arg_name, help_text, validator in positional_args:
            parser.add_argument(arg_name, help=help_text, type=validator)

        parsed = parser.parse_args()
        percentage = float(parsed.Percentage)
        return parsed.Input_Tractogram, percentage


    if __name__ == "__main__":
        # Command-line entry point: resample every streamline of the input
        # tractogram and write it next to the input with a "_resampled" suffix.
        in_file, perc = setup()

        streamlines = read_vtk(in_file)[0]
        resampled = streamlines_resample(streamlines, perc)
        file_name, file_extension = os.path.splitext(in_file)
        out_file = file_name + "_resampled" + file_extension
        print(out_file)
        # Resampled streamlines usually differ in length. np.array() on such
        # ragged input raises ValueError with NumPy >= 1.24, and save_vtk only
        # needs a sequence of point arrays, so pass the list directly.
        save_vtk(out_file, resampled)

        sys.exit()


    def save_tract(x, fname, NPOINTS=20):
        """Save a batch of flattened tracts to a VTK file via ``save_vtk``.

        ``x`` (presumably a torch tensor, given ``.detach().cpu()``) holds one
        row per tract. Each row is reshaped to ``(-1, 3)`` point rows and
        multiplied by ``sqrt(NPOINTS)`` before writing.
        """
        # reshape() also accepts non-contiguous tensors, where view() raises.
        tract = x.reshape(len(x), -1, 3) * np.sqrt(NPOINTS)
        save_vtk(fname, tract.detach().cpu().numpy())


    def save_tract_numpy(x, fname, NPOINTS=20):
        """Save a batch of flattened tracts as a float16 ``.npy`` array.

        ``x`` (presumably a torch tensor, given ``.detach().cpu()``) holds one
        row per tract. It is reshaped to ``(len(x), -1, 3)`` and multiplied by
        ``sqrt(NPOINTS)``. The result is stored with ``np.save`` (which appends
        ``.npy`` to ``fname`` if missing), with pickling disabled.
        """
        # reshape() also accepts non-contiguous tensors, where view() raises.
        tract = x.reshape(len(x), -1, 3) * np.sqrt(NPOINTS)
        # fix_imports is omitted: it only affects pickling, which is disabled
        # here, and the keyword is deprecated in recent NumPy releases.
        np.save(
            fname,
            np.float16(tract.detach().cpu().numpy()),
            allow_pickle=False,
        )


    def save_tract_with_labels(fname, x, scalars, subsampling_fibers=None, NPOINTS=20):
        """Save tracts to a VTK file with one integer label per fiber.

        Args:
            fname: output path passed to ``save_vtk_labels``.
            x: tensor with one row per fiber, reshaped to ``(len(x), -1, 3)``
                and multiplied by ``sqrt(NPOINTS)``.
            scalars: tensor expected to hold one label per fiber of ``x``.
                Every point of a fiber is written with that fiber's label.
            subsampling_fibers: keep only every ``subsampling_fibers``-th fiber
                (``None`` keeps all of them).
            NPOINTS: scale factor applied to the coordinates.
        """
        tract = x.reshape(len(x), -1, 3) * np.sqrt(NPOINTS)
        tract = tract[0::subsampling_fibers, :, :]
        # Subsample the labels with the same stride as the fibers so that the
        # number of labels matches the number of points actually written.
        fiber_labels = scalars.reshape(-1)[0::subsampling_fibers]
        points_per_fiber = tract.shape[1]
        labels = fiber_labels.reshape(-1, 1).repeat(1, points_per_fiber)
        save_vtk_labels(
            fname,
            tract.detach().cpu().numpy(),
            labels.reshape(-1).detach().cpu().numpy().astype(int),
        )


    def save_tracts_labels_separate(fname, x, labels, start, end, NPOINTS=20):
        """Write one VTK file per label value in ``range(start, end)``.

        ``x`` is reshaped to ``(len(x), -1, 3)`` and multiplied by
        ``sqrt(NPOINTS)``. For each label value, the fibers whose entry in
        ``labels`` equals it are saved to ``fname + "_{label:05d}.vtk"``.
        A label with no fiber still produces a file, holding a single
        one-point tract at the origin.
        """
        tract = x.reshape(len(x), -1, 3) * np.sqrt(NPOINTS)
        for label in range(start, end):
            out_name = fname + "_{:05d}.vtk".format(label)
            fiber_ids = (labels == label).nonzero().view(-1)
            if fiber_ids.shape[0] != 0:
                save_vtk(out_name, tract[fiber_ids, :, :].detach().cpu().numpy())
            else:
                # Shape (1, 1, 3): one tract made of one point. A (1, 3) array
                # would be read as one 3-point tract backed by a single point,
                # i.e. out-of-range point ids in the written cells.
                save_vtk(out_name, np.zeros((1, 1, 3)))


.. _sphx_glr_download__auto_examples_brain_tractograms_tract_io.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tract_io.ipynb <tract_io.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tract_io.py <tract_io.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_