Cómo implementar operaciones personalizadas en PyTorch con C++ y CUDA

Fuentes: Implementing custom PyTorch operations in C++ and CUDA
Imagen generada por IA con el prompt: Abstract digital illustration of glowing neural network layers connected by flowing data streams, gradient background in deep blue and amber, clean editorial tech style
Imagen generada con IA

Esta entrada del blog explica cómo implementar operaciones personalizadas en PyTorch utilizando C++ y CUDA, y cómo integrarlas tanto en modelos de PyTorch como en programas de inferencia compilados con AOTInductor. El ejemplo guía es una convolución identidad mínima que ilustra el ciclo completo: desde la definición en C++/CUDA hasta su uso en Python y su exportación con torch.export.

La primera parte aborda las funciones personalizadas sin estado, registradas mediante la macro TORCH_LIBRARY_IMPL. El autor muestra cómo proporcionar implementaciones separadas para CPU (un clone() elemento a elemento) y CUDA (un kernel que respeta la forma, el tipo de datos y los strides del tensor de entrada). La macro gestiona el envío automático a la implementación adecuada según el dispositivo del tensor. Esquema, kernel CUDA y registro de la operación conviven en un mismo archivo, lo que facilita la compilación en una biblioteca compartida (libidentity_conv_ops.so).

La segunda parte cubre las clases personalizadas con estado —capaces de almacenar parámetros— mediante torch::CustomClassHolder y la macro TORCH_LIBRARY. La clase IdentityConvClass se expone a Python como torch.classes.my_ops.IdentityConvClass, con constructor, método forward, descriptor channels, el protocolo de pytree (obj_flatten/obj_unflatten) que requiere torch.export y serialización TorchScript mediante def_pickle. La biblioteca se compila sin dependencia de pybind11, de modo que un binario C++ puro puede cargarla vía dlopen.

La tercera parte describe la integración en Python: la biblioteca se carga con torch.ops.load_library y, para que torch.compile y torch.export operen con FakeTensor, se registran versiones abstractas o “fake” mediante @register_fake_class y @torch.library.register_fake. Se detallan los métodos obj_flatten y obj_unflatten que la clase falsa debe implementar para que el trazado simbólico recorra correctamente los atributos del módulo que contienen una instancia de la clase C++.

El artículo está dirigido a desarrolladores que necesitan extender PyTorch con kernels CUDA optimizados, ya sea para acelerar cargas de trabajo concretas o para desplegar modelos compilados con AOTInductor. No trata la optimización interna del kernel ni la depuración, sino la infraestructura de registro y la compatibilidad con el resto de la cadena de herramientas.