K-NN classification - PyTorch API

The argKmin(K) reduction supported by the KeOps pykeops.torch.LazyTensor class allows us to perform a brute-force k-nearest neighbors search in four lines of code. It can thus be used to implement a large-scale K-NN classifier without memory overflows.
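As a quick preview, here is a minimal sketch of that four-line pattern, assuming that x is an (N, D) torch tensor of samples and K is an integer; the variable names are illustrative:

from pykeops.torch import LazyTensor

X_i = LazyTensor(x[:, None, :])  # (N, 1, D) symbolic "query" points
X_j = LazyTensor(x[None, :, :])  # (1, N, D) symbolic "database" points
D_ij = ((X_i - X_j) ** 2).sum(-1)  # (N, N) symbolic matrix of squared distances
ind_knn = D_ij.argKmin(K, dim=1)  # (N, K) indices of the K nearest neighbors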

Setup

Standard imports:

import time

import numpy as np
import torch
from matplotlib import pyplot as plt

from pykeops.torch import LazyTensor

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor  # Run on the GPU if available

Dataset, in 2D:

N, D = 10000 if use_cuda else 1000, 2  # Number of samples, dimension
x = torch.rand(N, D).type(dtype)  # Random samples on the unit square


# Random-ish class labels:
def fth(x):
    # Wiggly cubic curve used as a noisy decision boundary on the unit square:
    return 3 * x * (x - 0.5) * (x - 1) + x


# Boolean class labels: is the (noisy) point below the curve?
cl = x[:, 1] + 0.1 * torch.randn(N).type(dtype) < fth(x[:, 0])

Reference sampling grid, on the unit square:

M = 1000 if use_cuda else 100
tmp = torch.linspace(0, 1, M).type(dtype)  # M regularly spaced coordinates in [0, 1]
g2, g1 = torch.meshgrid(tmp, tmp, indexing="ij")  # (M, M) coordinate grids
g = torch.cat((g1.contiguous().view(-1, 1), g2.contiguous().view(-1, 1)), dim=1)  # (M*M, 2) grid points
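To see how these pieces fit together, here is a hedged sketch of how the grid points g could be classified with a K-NN majority vote over the labels cl; the value K = 3 and the intermediate names below are illustrative, not part of the original script:

K = 3  # Illustrative number of neighbors

G_i = LazyTensor(g[:, None, :])  # (M*M, 1, 2) query points on the grid
X_j = LazyTensor(x[None, :, :])  # (1, N, 2) training samples
D_ij = ((G_i - X_j) ** 2).sum(-1)  # (M*M, N) symbolic squared distances
ind_knn = D_ij.argKmin(K, dim=1)  # (M*M, K) indices of the K nearest samples
lab_knn = cl[ind_knn]  # (M*M, K) neighbor labels (booleans)
clg = lab_knn.float().mean(1) > 0.5  # Majority vote for each grid point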