1/* 2Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package internal 18 19/* 20#include <stdlib.h> 21#include <string.h> 22 23#include "tensorflow/c/c_api.h" 24*/ 25import "C" 26 27import ( 28 "errors" 29 "fmt" 30 "runtime" 31 "unsafe" 32 33 "github.com/golang/protobuf/proto" 34 pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework" 35) 36 37// Encapsulates a collection of API definitions. 38// 39// apiDefMap represents a map from operation name to corresponding 40// ApiDef proto (see 41// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto 42// for ApiDef proto definition). 43type apiDefMap struct { 44 c *C.TF_ApiDefMap 45} 46 47// Creates and returns a new apiDefMap instance. 48// 49// oplist is and OpList proto instance (see 50// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto 51// for OpList proto definition). 52 53func newAPIDefMap(oplist *pb.OpList) (*apiDefMap, error) { 54 // Create a buffer containing the serialized OpList. 55 opdefSerialized, err := proto.Marshal(oplist) 56 if err != nil { 57 return nil, fmt.Errorf("could not serialize OpDef for %s", oplist.String()) 58 } 59 data := C.CBytes(opdefSerialized) 60 defer C.free(data) 61 62 opbuf := C.TF_NewBuffer() 63 defer C.TF_DeleteBuffer(opbuf) 64 opbuf.data = data 65 opbuf.length = C.size_t(len(opdefSerialized)) 66 67 // Create ApiDefMap. 68 status := C.TF_NewStatus() 69 defer C.TF_DeleteStatus(status) 70 capimap := C.TF_NewApiDefMap(opbuf, status) 71 if C.TF_GetCode(status) != C.TF_OK { 72 return nil, errors.New(C.GoString(C.TF_Message(status))) 73 } 74 apimap := &apiDefMap{capimap} 75 runtime.SetFinalizer( 76 apimap, 77 func(a *apiDefMap) { 78 C.TF_DeleteApiDefMap(a.c) 79 }) 80 return apimap, nil 81} 82 83// Updates apiDefMap with the overrides specified in `data`. 84// 85// data - ApiDef text proto. 86func (m *apiDefMap) Put(data string) error { 87 cdata := C.CString(data) 88 defer C.free(unsafe.Pointer(cdata)) 89 status := C.TF_NewStatus() 90 defer C.TF_DeleteStatus(status) 91 C.TF_ApiDefMapPut(m.c, cdata, C.size_t(len(data)), status) 92 if C.TF_GetCode(status) != C.TF_OK { 93 return errors.New(C.GoString(C.TF_Message(status))) 94 } 95 return nil 96} 97 98// Returns ApiDef proto instance for the TensorFlow operation 99// named `opname`. 100func (m *apiDefMap) Get(opname string) (*pb.ApiDef, error) { 101 cname := C.CString(opname) 102 defer C.free(unsafe.Pointer(cname)) 103 status := C.TF_NewStatus() 104 defer C.TF_DeleteStatus(status) 105 apidefBuf := C.TF_ApiDefMapGet( 106 m.c, cname, C.size_t(len(opname)), status) 107 defer C.TF_DeleteBuffer(apidefBuf) 108 if C.TF_GetCode(status) != C.TF_OK { 109 return nil, errors.New(C.GoString(C.TF_Message(status))) 110 } 111 if apidefBuf == nil { 112 return nil, fmt.Errorf("could not find ApiDef for %s", opname) 113 } 114 115 var ( 116 apidef = new(pb.ApiDef) 117 size = int(apidefBuf.length) 118 // A []byte backed by C memory. 119 // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices 120 data = (*[1 << 30]byte)(unsafe.Pointer(apidefBuf.data))[:size:size] 121 err = proto.Unmarshal(data, apidef) 122 ) 123 if err != nil { 124 return nil, err 125 } 126 return apidef, nil 127} 128