microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
leca/pagedAttention

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/kernel_context.h

37 lines · mode: code

1#pragma once
2#include <optional>
3#include <numeric>
4#include <type_traits>
5#include "onnxruntime_c_api.h"
6
namespace Ort {
namespace Custom {

/// Common polymorphic root for argument types used by the ORT custom-op
/// template machinery.
///
/// Holds no state. Its only member is a virtual destructor, so any derived
/// argument object can be destroyed through an `Arg*`.
struct Arg {
  virtual ~Arg() = default;
};

/// Abstract execution context that a custom-op kernel can query.
///
/// Concrete implementations provide scratch-buffer allocation. NOTE(review):
/// buffers are presumably meant to be released via FreeScratchBuffer on the
/// same context; the actual ownership/lifetime rules live in the
/// implementation, confirm there.
///
/// The declaration order of the virtual functions below defines the vtable
/// layout; do not reorder them.
class KernelContext : public Arg {
 public:
  /// Requests a scratch buffer of `bytes` bytes.
  /// @return pointer to the buffer (semantics of failure are implementation-defined).
  virtual void* AllocScratchBuffer(size_t bytes) = 0;

  /// Releases a buffer that was obtained from AllocScratchBuffer.
  virtual void FreeScratchBuffer(void* buffer) = 0;

  // TODO: expose a thread pool?
};

#ifdef USE_CUDA
/// KernelContext extension for kernels that run with CUDA.
///
/// Stream and cuBLAS handles are returned as type-erased `void*`; callers
/// presumably cast them to `cudaStream_t` / `cublasHandle_t` (TODO confirm
/// against the implementations). Keep the declaration order stable: it
/// defines the vtable layout.
class CUDAKernelContext : public KernelContext {
 public:
  /// Requests a CUDA scratch buffer of `bytes` bytes.
  virtual void* AllocCudaScratchBuffer(size_t bytes) = 0;

  /// Releases a buffer that was obtained from AllocCudaScratchBuffer.
  virtual void FreeCudaScratchBuffer(void* buffer) = 0;

  /// Returns the CUDA stream handle, type-erased.
  virtual void* GetCudaStream() const = 0;

  /// Returns the cuBLAS handle, type-erased.
  virtual void* GetCublasHandle() const = 0;

  /// Returns the CUDA device id associated with this context.
  virtual int GetCudaDeviceId() const = 0;

  /// Optional hook for obtaining a scratch buffer for the given memory info.
  /// The base version ignores its arguments and returns nullptr; derived
  /// contexts may override it.
  virtual void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* /*mem_info*/,
                                                 size_t /*bytes*/) {
    return nullptr;
  }
};
#endif  // USE_CUDA

// TODO: helper to create a context from the global ORT environment.

}  // namespace Custom
}  // namespace Ort