# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Make the ``preprocess::tile_crop`` custom op available and export-friendly.

On import this module:

1. Binds ``tile_crop`` to ``torch.ops.preprocess.tile_crop.default``. If the op
   is not registered yet, it first loads the single
   ``libcustom_ops_aot_lib.*`` shared library found next to this file.
2. Registers a fake kernel for ``preprocess::tile_crop``. This lets
   export/tracing reason about the op's output shape without running the real
   implementation.
"""

import logging
from pathlib import Path

import torch

logger = logging.getLogger(__name__)


def _load_tile_crop_op():
    """Return the ``preprocess::tile_crop.default`` overload.

    Loads the AOT library next to this file when needed.

    Raises:
        RuntimeError: the op is not registered and there is not exactly one
            ``libcustom_ops_aot_lib.*`` file next to this module, or loading
            that library still did not register the op.
    """
    try:
        return torch.ops.preprocess.tile_crop.default
    except (AttributeError, RuntimeError):
        # Op not registered in this process yet; load the library below.
        pass

    libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(libs) != 1:
        raise RuntimeError(f"Expected 1 library but got {len(libs)}")
    logger.info("Loading custom ops library: %s", libs[0])
    torch.ops.load_library(libs[0])
    try:
        return torch.ops.preprocess.tile_crop.default
    except (AttributeError, RuntimeError) as err:
        raise RuntimeError(
            f"Loaded {libs[0]} but preprocess::tile_crop is still not registered"
        ) from err


# Handle to the real op overload.
tile_crop = _load_tile_crop_op()

# Library object for the "preprocess" namespace in IMPL mode, kept alive at
# module scope. NOTE(review): no kernels are registered through it in this
# file; confirm whether other code relies on it before removing.
preprocess_ops_lib = torch.library.Library("preprocess", "IMPL")

# Upper bound used for the fake kernel's dynamic leading (tile-count) dimension.
MAX_NUM_TILES = 4


# Register a fake kernel so export does not trace into the tile_crop impl.
# It is given a private name on purpose: naming it `tile_crop` would rebind
# the module-level `tile_crop` op handle above to this Python function.
@torch.library.register_fake("preprocess::tile_crop")
def _tile_crop_fake(output: torch.Tensor, tile_size: int) -> torch.Tensor:
    """Fake (shape-only) kernel for ``preprocess::tile_crop``.

    Args:
        output: The op's input tensor. The parameter name is historical.
            Presumably a (C, H, W) image, since ``size(0)`` is used as the
            channel count (TODO confirm against the C++ schema).
        tile_size: Height and width of each output tile.

    Returns:
        An empty tensor of shape ``[n, output.size(0), tile_size, tile_size]``.
        ``n`` is an unbacked symbolic size constrained to
        ``[0, MAX_NUM_TILES]``.
    """
    # Use an unbacked symint to create an upper-bounded dynamic shape output.
    # Otherwise the output would be a static shape, and we could only output
    # tensors of shape [MAX_NUM_TILES, C, tile_size, tile_size].
    ctx = torch._custom_ops.get_ctx()
    s0 = ctx.create_unbacked_symint()
    torch._constrain_as_size(s0, 0, MAX_NUM_TILES)
    # NOTE(review): uses the default dtype/device rather than output's;
    # confirm this matches the real kernel's output for non-float32 inputs.
    return torch.empty([s0, output.size(0), tile_size, tile_size])