# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path

import torch

# Bind the custom tile_crop op. If it is not yet registered, locate the
# AOT-compiled custom ops library next to this file, load it, and retry.
try:
    tile_crop = torch.ops.preprocess.tile_crop.default
    assert tile_crop is not None
except Exception:
    libs = list(Path(__file__).parent.resolve().glob("libcustom_ops_aot_lib.*"))
    assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
    logging.info(f"Loading custom ops library: {libs[0]}")
    torch.ops.load_library(libs[0])
    tile_crop = torch.ops.preprocess.tile_crop.default
    assert tile_crop is not None

# Library handle for registering kernels into the existing "preprocess" namespace.
preprocess_ops_lib = torch.library.Library("preprocess", "IMPL")

# Upper bound on the number of tiles tile_crop can produce.
MAX_NUM_TILES = 4


# Register meta kernel to prevent export tracing into the tile_crop impl.
# Named tile_crop_meta so it does not rebind the module-level `tile_crop` op.
@torch.library.register_fake("preprocess::tile_crop")
def tile_crop_meta(output: torch.Tensor, tile_size: int) -> torch.Tensor:
    # Returned tensor is of size [n, 3, 224, 224], where n = number of tiles.
    # Use an unbacked symint to create an upper-bounded dynamic shape output.
    # Otherwise, output is set to a static shape, and we can only output
    # tensors of shape [MAX_NUM_TILES, 3, 224, 224].
    ctx = torch._custom_ops.get_ctx()
    s0 = ctx.create_unbacked_symint()
    torch._constrain_as_size(s0, 0, MAX_NUM_TILES)
    return torch.empty([s0, output.size(0), tile_size, tile_size])
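

# Illustrative usage sketch (not part of the original module): export a small
# wrapper that calls the custom op. During torch.export tracing, the fake
# kernel above supplies the upper-bounded dynamic output shape instead of
# tracing into the compiled implementation. The wrapper class, input shape,
# and tile size are hypothetical example values; running this requires the
# compiled custom ops library to be loadable.
if __name__ == "__main__":

    class _TileCropExample(torch.nn.Module):
        def forward(self, image: torch.Tensor) -> torch.Tensor:
            # image: [C, H, W] -> tiles: [n, C, tile_size, tile_size]
            return torch.ops.preprocess.tile_crop(image, 224)

    exported = torch.export.export(_TileCropExample(), (torch.ones(3, 448, 448),))
    print(exported)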