.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/performances/plot_profile.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__auto_examples_performances_plot_profile.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_performances_plot_profile.py:


Profile the GeomLoss routines
===================================

This example shows how to **profile** the geometric losses
to select the backend and the truncation/scaling parameters
(``backend``, ``truncate``, ``scaling``) that are best suited to your data.

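The profiling loop of this example compares backends by eye. If you would rather let a script pick a configuration, a small timing harness is enough. The sketch below is a hypothetical helper, not part of GeomLoss: ``best_config`` takes a factory ``make_loss(**config)`` that returns a callable, a list of configurations and a ``run(fn)`` function that performs one pass. It times each configuration with ``time.perf_counter`` (best of a few runs, after one warm-up call) and returns the fastest. To stay runnable without a GPU, it times a dummy workload; the commented lines show how it might be wired to ``SamplesLoss``, whose ``backend``, ``scaling`` and ``truncate`` keyword arguments are the ones discussed above. On the GPU, ``run`` should also call ``torch.cuda.synchronize()`` before returning, since CUDA kernels run asynchronously.

.. code-block:: Python

    import itertools
    import time


    def best_config(make_loss, configs, run, repeat=3):
        """Return (fastest config, its best wall-clock time in seconds)."""
        timings = []
        for config in configs:
            fn = make_loss(**config)
            run(fn)  # warm-up call, so that one-off costs (compilation, caching) are not timed
            best = float("inf")
            for _ in range(repeat):
                t_0 = time.perf_counter()
                run(fn)
                best = min(best, time.perf_counter() - t_0)
            timings.append((best, config))
        best_time, best_cfg = min(timings, key=lambda t: t[0])
        return best_cfg, best_time


    # Hypothetical stand-in for a loss: its cost grows with n_steps.
    def make_dummy(n_steps, scaling):
        return lambda: sum(i * scaling for i in range(n_steps))


    grid = [
        {"n_steps": n, "scaling": s}
        for n, s in itertools.product([1_000, 100_000], [0.5, 0.9])
    ]
    cfg, seconds = best_config(make_dummy, grid, run=lambda fn: fn())
    print(cfg, "{:.2e}s".format(seconds))

    # With GeomLoss, one might instead use (not run here):
    #   make_loss = lambda **cfg: SamplesLoss("sinkhorn", blur=0.05, **cfg)
    #   run = lambda fn: fn(x, y).backward()
    #   grid = [{"backend": b, "scaling": s}
    #           for b in ("online", "multiscale") for s in (0.5, 0.9)]
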
.. GENERATED FROM PYTHON SOURCE LINES 12-14

Setup
---------------------

.. GENERATED FROM PYTHON SOURCE LINES 14-22

.. code-block:: Python


    import torch
    from geomloss import SamplesLoss
    from time import time

    use_cuda = torch.cuda.is_available()
    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor








.. GENERATED FROM PYTHON SOURCE LINES 23-25

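Sample ``N`` points ``x`` and ``M`` points ``y`` uniformly on the sphere
of radius 1/2 (i.e. of unit diameter), using 100 points per cloud on the CPU
and 100,000 on the GPU. We set ``x.requires_grad = True`` so that the
backward pass is profiled as well: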


.. GENERATED FROM PYTHON SOURCE LINES 25-31

.. code-block:: Python


    N, M = (100, 100) if not use_cuda else (100000, 100000)
    x, y = torch.randn(N, 3).type(dtype), torch.randn(M, 3).type(dtype)
    x, y = x / (2 * x.norm(dim=1, keepdim=True)), y / (2 * y.norm(dim=1, keepdim=True))
    x.requires_grad = True


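Why does the code above produce uniform samples on a sphere? The standard Gaussian distribution in R^3 is invariant under rotations, so the normalized samples ``x / |x|`` are uniformly distributed on the unit sphere; dividing by ``2 * |x|`` instead moves them onto the sphere of radius 1/2. By Archimedes' hat-box theorem, each coordinate is then uniform on [-1/2, 1/2], so the bounding box of the data is close to the cube [-1/2, 1/2]^3, whose diagonal is √3 ≈ 1.732. The sketch below checks these facts with the Python standard library only (it does not use PyTorch or GeomLoss); ``sample_half_sphere`` is a hypothetical helper written for this illustration.

.. code-block:: Python

    import math
    import random

    random.seed(0)


    def sample_half_sphere(n, dim=3):
        """Uniform samples on the sphere of radius 1/2 in R^dim."""
        points = []
        for _ in range(n):
            g = [random.gauss(0.0, 1.0) for _ in range(dim)]  # isotropic Gaussian sample
            norm = math.sqrt(sum(c * c for c in g))
            points.append([c / (2 * norm) for c in g])  # project onto the radius-1/2 sphere
        return points


    pts = sample_half_sphere(10_000)
    radii = [math.sqrt(sum(c * c for c in p)) for p in pts]
    mean = [sum(p[k] for p in pts) / len(pts) for k in range(3)]
    lo = [min(p[k] for p in pts) for k in range(3)]
    hi = [max(p[k] for p in pts) for k in range(3)]
    diag = math.dist(lo, hi)  # bounding-box diagonal, close to sqrt(3)

    print(min(radii), max(radii))  # both 0.5, up to rounding errors
    print(mean)                    # close to the origin
    print(diag)                    # close to 1.732
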






.. GENERATED FROM PYTHON SOURCE LINES 32-33

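For both losses (a Gaussian kernel norm and the Sinkhorn divergence) and both
backends (``"online"`` and ``"multiscale"``), time one forward and backward
pass and record it with the PyTorch profiler, which exports it as a Chrome trace file: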

.. GENERATED FROM PYTHON SOURCE LINES 33-49

.. code-block:: Python


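    import os

    # The traces are written into "output/": make sure that folder exists
    os.makedirs("output", exist_ok=True)

    for loss in ["gaussian", "sinkhorn"]:
        for backend in ["online", "multiscale"]:
            with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
                Loss = SamplesLoss(
                    loss, blur=0.05, backend=backend, truncate=3, verbose=True
                )
                t_0 = time()
                L_xy = Loss(x, y)
                L_xy.backward()
                if use_cuda:
                    # CUDA kernels run asynchronously: wait for them before stopping the clock
                    torch.cuda.synchronize()
                t_1 = time()
                print("{:.2f}s, cost = {:.6f}".format(t_1 - t_0, L_xy.item()))

            # Export once the profiler context has been closed
            prof.export_chrome_trace("output/profile_" + loss + "_" + backend + ".json")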






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.08s, cost = 0.000010
    707x711 clusters, computed at scale = 1.587
    0.04s, cost = 0.000010
    0.68s, cost = 0.000011
    707x711 clusters, computed at scale = 0.079
    Successive scales :  1.732, 1.732, 0.866, 0.433, 0.217, 0.108, 0.054, 0.050
    Jump from coarse to fine between indices 5 (σ=0.108) and 6 (σ=0.054).
    Keep 79729/502677 = 15.9% of the coarse cost matrix.
    Keep 79317/499849 = 15.9% of the coarse cost matrix.
    Keep 80193/505521 = 15.9% of the coarse cost matrix.
    0.09s, cost = 0.000012


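The verbose log of the multiscale Sinkhorn run can be decoded with a little arithmetic. The "Successive scales" form a geometric sequence: the first value, 1.732 ≈ √3, matches the diagonal of the bounding box [-1/2, 1/2]^3 of the data, and each following value is half of the previous one (0.5 is the default ``scaling`` of ``SamplesLoss``) until the target ``blur=0.05`` is reached. The sketch below reproduces the printed numbers under these assumptions. It also assumes that the coarse-to-fine jump happens right after the last scale that is still larger than the cluster scale printed above (0.079). It is a reading of the log, not GeomLoss's own code, and ``scale_schedule`` is a hypothetical helper.

.. code-block:: Python

    import math


    def scale_schedule(diameter, blur, scaling):
        """Geometric annealing from `diameter` down to `blur` (sketch of the logged schedule)."""
        scales = [diameter]  # first iteration at the full diameter
        s = diameter
        while s > blur:
            scales.append(s)
            s *= scaling
        return scales + [blur]  # last iteration at the target blur


    scales = scale_schedule(diameter=math.sqrt(3), blur=0.05, scaling=0.5)
    formatted = ", ".join("{:.3f}".format(s) for s in scales)
    print("Successive scales : ", formatted)

    # Assumption: coarse -> fine jump after the last scale above the cluster scale.
    cluster_scale = 0.079
    jump = sum(s > cluster_scale for s in scales) - 1
    print("Jump from coarse to fine between indices", jump, "and", jump + 1)
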

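The three "Keep" lines of the same log are consistent with the 707 clusters of ``x`` and the 711 clusters of ``y``: their denominators are exactly 707 * 711, 707 * 707 and 711 * 711. This suggests one truncated coarse cost matrix for each of the x-y, x-x and y-y transport problems that the debiased Sinkhorn divergence combines. The quick check below recomputes these sizes and the kept fractions from the numbers printed in the log; the x-y / x-x / y-y labels are our interpretation.

.. code-block:: Python

    # Numbers copied from the verbose log above.
    n_x, n_y = 707, 711  # number of clusters in x and in y
    kept = {"x-y": 79729, "x-x": 79317, "y-y": 80193}
    sizes = {"x-y": n_x * n_y, "x-x": n_x * n_x, "y-y": n_y * n_y}

    for key in kept:
        ratio = kept[key] / sizes[key]
        print("{}: keep {}/{} = {:.1%}".format(key, kept[key], sizes[key], ratio))
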

.. GENERATED FROM PYTHON SOURCE LINES 50-54

Now, all you have to do is open the "Easter egg" page
``chrome://tracing`` in Google Chrome or Chromium
and load the ``output/profile_*.json`` files one after
another. Enjoy :-)

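If you only need the heaviest operations, you can also summarize a trace file in Python instead of browsing it. The sketch below assumes the JSON layout of the Chrome Trace Event Format: "complete" events have ``"ph": "X"``, a ``"name"`` and a duration ``"dur"`` in microseconds, stored either in a top-level list or under a ``"traceEvents"`` key (the exact layout written by ``export_chrome_trace`` has varied across PyTorch versions). ``top_ops`` is a hypothetical helper, not part of PyTorch or GeomLoss. Nested events are counted in full, so the totals of parent and child operations overlap.

.. code-block:: Python

    import json
    from collections import Counter


    def top_ops(path, k=10):
        """Total duration (ms) per event name in a Chrome trace file, largest first."""
        with open(path) as f:
            trace = json.load(f)
        events = trace["traceEvents"] if isinstance(trace, dict) else trace
        totals = Counter()
        for event in events:
            if event.get("ph") == "X":  # "complete" events carry a duration
                totals[event.get("name", "?")] += float(event.get("dur", 0)) / 1000.0  # us -> ms
        return totals.most_common(k)


    # Usage, assuming the traces above were written to the "output" folder:
    #   for name, ms in top_ops("output/profile_sinkhorn_multiscale.json"):
    #       print("{:>10.2f} ms  {}".format(ms, name))
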
.. GENERATED FROM PYTHON SOURCE LINES 54-56

.. code-block:: Python


    print("Done.")




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Done.





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.039 seconds)


.. _sphx_glr_download__auto_examples_performances_plot_profile.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_profile.ipynb <plot_profile.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_profile.py <plot_profile.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_